Skip to content

Commit c51a517

Browse files
committed
perf: FNOs
1 parent 89b526e commit c51a517

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

perf/neuraloperators/main.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,26 @@ function run_deeponet_benchmarks()
3636
return nothing
3737
end
3838

39+
function run_fno_benchmarks()
40+
@info "Running FNO benchmarks"
41+
42+
model = FourierNeuralOperator((16, 16), 3, 8, 64)
43+
ps, st = xdev(Lux.setup(Random.default_rng(), model))
44+
x = xdev(rand(Float32, 64, 64, 1, 256))
45+
46+
primal_timings = Reactant.with_config(;
47+
dot_general_precision=PrecisionConfig.HIGH,
48+
convolution_precision=PrecisionConfig.HIGH,
49+
) do
50+
benchmark_nn_primal(
51+
model, x, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
52+
)
53+
end
54+
55+
pretty_print_table(permutedims(hcat([[t...] for t in primal_timings]...), (2, 1)))
56+
57+
return nothing
58+
end
59+
60+
run_fno_benchmarks()
3961
run_deeponet_benchmarks()

0 commit comments

Comments
 (0)