@@ -2,10 +2,10 @@ import ADTypes,
22 BenchmarkTools,
33 ComponentArrays,
44 DifferentiationInterface,
5- Enzyme,
65 Lux,
76 PkgBenchmark,
87 StableRNGs,
8+ Zygote,
99 ContinuousNormalizingFlows
1010
1111SUITE = BenchmarkTools. BenchmarkGroup()
@@ -33,12 +33,7 @@ icnf = ContinuousNormalizingFlows.construct(
3333 nn,
3434 nvars,
3535 naugs;
36- compute_mode = ContinuousNormalizingFlows. DIVecJacMatrixMode(
37- ADTypes. AutoEnzyme(;
38- mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
39- function_annotation = Enzyme. Const,
40- ),
41- ),
36+ compute_mode = ContinuousNormalizingFlows. DIVecJacMatrixMode(ADTypes. AutoZygote()),
4237 tspan = (0.0f0 , 13.0f0 ),
4338 steer_rate = 1.0f-1 ,
4439 λ₃ = 1.0f-2 ,
5752
5853diff_loss_tn(ps)
5954diff_loss_tt(ps)
60- DifferentiationInterface. gradient(
61- diff_loss_tn,
62- ADTypes. AutoEnzyme(;
63- mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
64- function_annotation = Enzyme. Const,
65- ),
66- ps,
67- )
68- DifferentiationInterface. gradient(
69- diff_loss_tt,
70- ADTypes. AutoEnzyme(;
71- mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
72- function_annotation = Enzyme. Const,
73- ),
74- ps,
75- )
55+ DifferentiationInterface. gradient(diff_loss_tn, ADTypes. AutoZygote(), ps)
56+ DifferentiationInterface. gradient(diff_loss_tt, ADTypes. AutoZygote(), ps)
7657GC. gc()
7758
7859SUITE[" main" ][" no_inplace" ][" direct" ][" train" ] =
@@ -82,19 +63,13 @@ SUITE["main"]["no_inplace"]["direct"]["test"] =
8263SUITE[" main" ][" no_inplace" ][" AD-1-order" ][" train" ] =
8364 BenchmarkTools. @benchmarkable DifferentiationInterface. gradient(
8465 diff_loss_tn,
85- ADTypes. AutoEnzyme(;
86- mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
87- function_annotation = Enzyme. Const,
88- ),
66+ ADTypes. AutoZygote(),
8967 ps,
9068 )
9169SUITE[" main" ][" no_inplace" ][" AD-1-order" ][" test" ] =
9270 BenchmarkTools. @benchmarkable DifferentiationInterface. gradient(
9371 diff_loss_tt,
94- ADTypes. AutoEnzyme(;
95- mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
96- function_annotation = Enzyme. Const,
97- ),
72+ ADTypes. AutoZygote(),
9873 ps,
9974 )
10075
@@ -104,12 +79,7 @@ icnf2 = ContinuousNormalizingFlows.construct(
10479 nvars,
10580 naugs;
10681 inplace = true ,
107- compute_mode = ContinuousNormalizingFlows. DIVecJacMatrixMode(
108- ADTypes. AutoEnzyme(;
109- mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
110- function_annotation = Enzyme. Const,
111- ),
112- ),
82+ compute_mode = ContinuousNormalizingFlows. DIVecJacMatrixMode(ADTypes. AutoZygote()),
11383 tspan = (0.0f0 , 13.0f0 ),
11484 steer_rate = 1.0f-1 ,
11585 λ₃ = 1.0f-2 ,
12595
12696diff_loss_tn2(ps)
12797diff_loss_tt2(ps)
128- DifferentiationInterface. gradient(
129- diff_loss_tn2,
130- ADTypes. AutoEnzyme(;
131- mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
132- function_annotation = Enzyme. Const,
133- ),
134- ps,
135- )
136- DifferentiationInterface. gradient(
137- diff_loss_tt2,
138- ADTypes. AutoEnzyme(;
139- mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
140- function_annotation = Enzyme. Const,
141- ),
142- ps,
143- )
98+ DifferentiationInterface. gradient(diff_loss_tn2, ADTypes. AutoZygote(), ps)
99+ DifferentiationInterface. gradient(diff_loss_tt2, ADTypes. AutoZygote(), ps)
144100GC. gc()
145101
146102SUITE[" main" ][" inplace" ][" direct" ][" train" ] =
@@ -149,18 +105,12 @@ SUITE["main"]["inplace"]["direct"]["test"] = BenchmarkTools.@benchmarkable diff_
149105SUITE[" main" ][" inplace" ][" AD-1-order" ][" train" ] =
150106 BenchmarkTools. @benchmarkable DifferentiationInterface. gradient(
151107 diff_loss_tn2,
152- ADTypes. AutoEnzyme(;
153- mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
154- function_annotation = Enzyme. Const,
155- ),
108+ ADTypes. AutoZygote(),
156109 ps,
157110 )
158111SUITE[" main" ][" inplace" ][" AD-1-order" ][" test" ] =
159112 BenchmarkTools. @benchmarkable DifferentiationInterface. gradient(
160113 diff_loss_tt2,
161- ADTypes. AutoEnzyme(;
162- mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
163- function_annotation = Enzyme. Const,
164- ),
114+ ADTypes. AutoZygote(),
165115 ps,
166116 )
0 commit comments