@@ -2,27 +2,30 @@ using Boltz, Lux, Random, LuxCUDA
2
2
using Reactant
3
3
using BenchmarkTools
4
4
5
+ Reactant.set_default_backend("gpu")
6
+
5
7
dev = gpu_device()
6
8
7
9
model = Vision.ViT(:tiny);
8
10
ps, st = Lux.setup(Random.default_rng(), model);
9
11
10
- ps_gpu, st_gpu = ps, st |> dev;
12
+ ps_gpu, st_gpu = (ps, st) |> dev;
13
+
14
+ x = rand(Float32, 256, 256, 3, 4);
15
+ x_gpu = x |> dev;
11
16
12
- x = rand(Float32, 256, 256, 3, 16);
17
+ lux_timing = @benchmark begin
18
+ Lux.apply($model, $x_gpu, $ps_gpu, $st_gpu)
19
+ CUDA.synchronize()
20
+ end
13
21
14
22
x_ra = Reactant.to_rarray(x);
15
23
ps_ra = Reactant.to_rarray(ps);
16
24
st_ra = Reactant.to_rarray(st);
17
25
18
26
apply_compiled = @compile Lux.apply(model, x_ra, ps_ra, st_ra);
19
27
20
- lux_timing = @benchmark begin
21
- Lux.apply($model, $x, $ps, $st)
22
- CUDA.synchronize()
23
- end
24
-
25
28
reactant_timing = @benchmark begin
26
- res = $apply_compiled($model, $x_ra, $ps_ra, $st_ra)
29
+ res, _ = $apply_compiled($model, $x_ra, $ps_ra, $st_ra)
27
30
Reactant.synchronize(res)
28
31
end
0 commit comments