Skip to content

Commit f724d7e

Browse files
committed
fix: correct GPU benchmarking
1 parent db7dbab commit f724d7e

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

benchmark/vit.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,30 @@ using Boltz, Lux, Random, LuxCUDA
22
using Reactant
33
using BenchmarkTools
44

5+
Reactant.set_default_backend("gpu")
6+
57
dev = gpu_device()
68

79
model = Vision.ViT(:tiny);
810
ps, st = Lux.setup(Random.default_rng(), model);
911

10-
ps_gpu, st_gpu = ps, st |> dev;
12+
ps_gpu, st_gpu = (ps, st) |> dev;
13+
14+
x = rand(Float32, 256, 256, 3, 4);
15+
x_gpu = x |> dev;
1116

12-
x = rand(Float32, 256, 256, 3, 16);
17+
lux_timing = @benchmark begin
18+
Lux.apply($model, $x_gpu, $ps_gpu, $st_gpu)
19+
CUDA.synchronize()
20+
end
1321

1422
x_ra = Reactant.to_rarray(x);
1523
ps_ra = Reactant.to_rarray(ps);
1624
st_ra = Reactant.to_rarray(st);
1725

1826
apply_compiled = @compile Lux.apply(model, x_ra, ps_ra, st_ra);
1927

20-
lux_timing = @benchmark begin
21-
Lux.apply($model, $x, $ps, $st)
22-
CUDA.synchronize()
23-
end
24-
2528
reactant_timing = @benchmark begin
26-
res = $apply_compiled($model, $x_ra, $ps_ra, $st_ra)
29+
res, _ = $apply_compiled($model, $x_ra, $ps_ra, $st_ra)
2730
Reactant.synchronize(res)
2831
end

0 commit comments

Comments
 (0)