Skip to content

Commit d786975

Browse files
authored
Merge pull request #130 from EnzymeAD/ap/vit
feat: compiling vision transformers
2 parents 831fbdc + 7b71592 commit d786975

File tree

9 files changed

+297
-217
lines changed

9 files changed

+297
-217
lines changed

benchmark/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
44
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
55
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
66
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
7+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
78

89
[compat]
910
BenchmarkTools = "1.5"

benchmark/benchmarks.jl

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,68 @@ using Boltz, Lux, Random
55

66
const SUITE = BenchmarkGroup()
77

8+
SUITE["runtime"] = BenchmarkGroup()
89
SUITE["comptime"] = BenchmarkGroup()
910

1011
SUITE["comptime"]["basics"] = BenchmarkGroup()
1112
SUITE["comptime"]["basics"]["2D sum"] = @benchmarkable Reactant.compile(sum, (a,)) setup = (
1213
a = Reactant.ConcreteRArray(ones(2, 10))
1314
)
14-
SUITE["comptime"]["basics"]["Basic cos"] = @benchmarkable Reactant.compile(cos, (a,)) setup = (
15+
16+
# Elementwise cosine as a named function, so that the broadcast can be handed to
# `Reactant.compile` like any other callable.
function bcast_cos(x)
    return cos.(x)
end
17+
18+
SUITE["comptime"]["basics"]["cos.(x)"] = @benchmarkable begin
19+
Reactant.compile(bcast_cos, (a,))
20+
end setup = begin
1521
a = Reactant.ConcreteRArray(ones(2, 10))
16-
)
22+
end
1723

24+
SUITE["runtime"]["lux neural networks"] = BenchmarkGroup()
1825
SUITE["comptime"]["lux neural networks"] = BenchmarkGroup()
1926

20-
for depth in [11, 13, 16, 19]
21-
SUITE["comptime"]["lux neural networks"]["vgg$depth"] = @benchmarkable Reactant.compile(
22-
vgg, (x, ps_concrete, st_concrete)
23-
) setup = begin
24-
vgg = Vision.VGG($depth; pretrained=false, batchnorm=false)
27+
# VGG benchmarks: one compile-time and one run-time entry per configuration.
# Only `batchnorm = false` is swept for now; `batchnorm = true` is not working yet.
for depth in (11, 13, 16, 19)
    for batchnorm in (false,)
        label = "vgg$(depth) bn=$(batchnorm)"

        # Compile time: the benchmarked expression is the `@compile` call itself. The
        # model, its parameters/states and the input batch are all built in `setup`.
        SUITE["comptime"]["lux neural networks"][label] = @benchmarkable begin
            @compile vgg(x, ps_concrete, st_concrete)
        end setup = begin
            vgg = Vision.VGG($depth; pretrained=false, batchnorm=$(batchnorm))
            ps, st = Lux.setup(Random.default_rng(), vgg)
            ps_concrete = Reactant.to_rarray(ps)
            st_concrete = Reactant.to_rarray(Lux.testmode(st))
            x = Reactant.to_rarray(rand(Float32, 224, 224, 3, 16))
        end

        # Run time: compilation happens in `setup`, so the benchmarked expression is
        # only the call to the already-compiled model.
        SUITE["runtime"]["lux neural networks"]["$(label) (compiled)"] = @benchmarkable begin
            vgg_compiled(x, ps_concrete, st_concrete)
        end setup = begin
            vgg = Vision.VGG($depth; pretrained=false, batchnorm=$(batchnorm))
            ps, st = Lux.setup(Random.default_rng(), vgg)
            ps_concrete = Reactant.to_rarray(ps)
            st_concrete = Reactant.to_rarray(Lux.testmode(st))
            x = Reactant.to_rarray(rand(Float32, 224, 224, 3, 16))
            vgg_compiled = @compile vgg(x, ps_concrete, st_concrete)
        end
    end
end
49+
50+
# ViT benchmarks: one compile-time and one run-time entry per model size.
for version in (:tiny, :base)
    label = "ViT $(version)"

    # Compile time: the benchmarked expression is the `@compile` call itself; everything
    # it needs is built in `setup`.
    # `Meta.quot` keeps the interpolated value a `Symbol` literal; interpolating the bare
    # symbol would splice it in as an identifier instead.
    SUITE["comptime"]["lux neural networks"][label] = @benchmarkable begin
        @compile vit(x, ps_concrete, st_concrete)
    end setup = begin
        vit = Vision.ViT($(Meta.quot(version)))
        ps, st = Lux.setup(Random.default_rng(), vit)
        ps_concrete = Reactant.to_rarray(ps)
        st_concrete = Reactant.to_rarray(Lux.testmode(st))
        x = Reactant.to_rarray(rand(Float32, 256, 256, 3, 16))
    end

    # Run time: compilation happens in `setup`, so the benchmarked expression is only
    # the call to the already-compiled model.
    SUITE["runtime"]["lux neural networks"]["$(label) (compiled)"] = @benchmarkable begin
        vit_compiled(x, ps_concrete, st_concrete)
    end setup = begin
        vit = Vision.ViT($(Meta.quot(version)))
        ps, st = Lux.setup(Random.default_rng(), vit)
        ps_concrete = Reactant.to_rarray(ps)
        st_concrete = Reactant.to_rarray(Lux.testmode(st))
        x = Reactant.to_rarray(rand(Float32, 256, 256, 3, 16))
        vit_compiled = @compile vit(x, ps_concrete, st_concrete)
    end
end
3172

ext/ReactantNNlibExt.jl

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module ReactantNNlibExt
22

33
using NNlib
4-
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array
4+
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR
55

66
for (jlop, hloop) in (
77
(:(NNlib.tanh_fast), :tanh),
@@ -19,12 +19,11 @@ for (jlop, hloop) in (
1919
end
2020
end
2121

22-
NNlib.relu(x::TracedRArray{T,0}) where {T} = max(x, zero(T))
23-
24-
function NNlib.gelu(x::TracedRArray{T,0}) where {T}
25-
α = T(0.044715)
26-
λλ = T((8 / π))
27-
return x * sigmoid(λλ * x * muladd(x^2, α, one(T)))
22+
# Traced scalars are 0-dimensional `TracedRArray`s rather than plain numbers. For each
# NNlib activation that is not skipped below, define a method for the 0-D case that
# forwards, via `invoke`, to the method NNlib selects for a single `Any` argument.
# NOTE(review): presumably this sidesteps array-specific methods that would otherwise be
# picked for a `TracedRArray`, and the skipped ops are presumably handled by dedicated
# overloads elsewhere in this file — confirm both.
let skipped = (:tanh_fast, :sigmoid_fast, :sigmoid)
    for act in setdiff(Tuple(NNlib.ACTIVATIONS), skipped)
        @eval function NNlib.$(act)(x::TracedRArray{T,0}) where {T}
            return invoke(NNlib.$(act), Tuple{Any}, x)
        end
    end
end
2928

3029
# TODO handle non finite cases
@@ -206,4 +205,52 @@ end
206205
NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3))
207206
NNlib.batched_adjoint(x::AnyTracedRArray{<:Real,3}) = NNlib.batched_transpose(x)
208207

208+
# Batched matrix product of two traced 3-D arrays, emitted as one `stablehlo.dot_general`.
#
# The third dimension is the batch dimension. Batch sizes must match, or one of them must
# be 1, in which case that operand is broadcast along the batch dimension. The inner
# dimensions must agree (`size(x, 2) == size(y, 1)`).
#
# Returns a `TracedRArray{T,3}` of size `(size(x, 1), size(y, 2), B)`, where `B` is the
# larger of the two batch sizes.
#
# Throws `DimensionMismatch` when the batch sizes or the inner dimensions are incompatible.
function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) where {T}
    batch_ok = size(x, 3) == size(y, 3) || size(x, 3) == 1 || size(y, 3) == 1
    inner_ok = size(x, 2) == size(y, 1)
    if !(batch_ok && inner_ok)
        throw(
            DimensionMismatch(
                lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.",
            ),
        )
    end

    # Bring the batch dimension to the front: (rows, cols, batch) -> (batch, rows, cols).
    lhs = permutedims(x, (3, 1, 2))
    rhs = permutedims(y, (3, 1, 2))

    lhs_batch, rhs_batch = size(lhs, 1), size(rhs, 1)
    B = max(lhs_batch, rhs_batch)
    out_shape = (B, size(lhs, 2), size(rhs, 3))
    resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(lhs.mlir_data)))

    # If only one operand has a singleton batch, expand it to the common batch size.
    if lhs_batch != rhs_batch && lhs_batch == 1
        lhs = Reactant.broadcast_to_size(lhs, (B, size(lhs, 2), size(lhs, 3)))
    elseif lhs_batch != rhs_batch && rhs_batch == 1
        rhs = Reactant.broadcast_to_size(rhs, (B, size(rhs, 2), size(rhs, 3)))
    end

    # 0-based dimension indices: batch over dim 0 of both operands, and contract dim 2
    # of the left operand with dim 1 of the right operand.
    dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet(
        MLIR.IR.context(), 1, [0], 1, [0], 1, [2], 1, [1]
    )

    # NOTE(review): a single precision attribute is passed as `precision_config`; the
    # StableHLO spec describes this as one entry per operand — confirm this form is
    # accepted as intended.
    prec = MLIR.IR.Attribute(
        MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
    )

    dot_op = MLIR.Dialects.stablehlo.dot_general(
        lhs.mlir_data,
        rhs.mlir_data;
        result_0=resty,
        dot_dimension_numbers=dot_dimension_numbers,
        precision_config=prec,
    )
    res = TracedRArray{T,3}((), MLIR.IR.result(dot_op, 1), size(resty))

    # Move the batch dimension back to the end: (batch, rows, cols) -> (rows, cols, batch).
    return permutedims(res, (2, 3, 1))
end
255+
209256
end # module ReactantNNlibExt

src/ConcreteRArray.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,El
4848
# XLA.from_row_major(data)
4949
end
5050

51+
# Call `XLA.synced_buffer` on `x.data`, discard its result, and return `nothing`.
# NOTE(review): presumably this blocks until pending device work on `x` has finished —
# confirm against `XLA.synced_buffer`.
function synchronize(x::ConcreteRArray)
    buffer = x.data
    XLA.synced_buffer(buffer)
    return nothing
end
55+
5156
# function Base.similar(x::ConcreteRArray{T,N}, ::Type{T2}) where {T,N,T2}
5257
# return ConcreteRArray{T,N}(x.data)
5358
# end

0 commit comments

Comments
 (0)