@@ -22,6 +22,7 @@ Test.@testset "Smoke Tests" begin
2222 data_types = Type{<: AbstractFloat }[Float32]
2323 devices = MLDataDevices. AbstractDevice[MLDataDevices. cpu_device()]
2424 adtypes = ADTypes. AbstractADType[ADTypes. AutoZygote(),
25+ # ADTypes.AutoForwardDiff(),
2526 # ADTypes.AutoEnzyme(;
2627 # mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
2728 # function_annotation = Enzyme.Const,
@@ -30,27 +31,29 @@ Test.@testset "Smoke Tests" begin
3031 # mode = Enzyme.set_runtime_activity(Enzyme.Forward),
3132 # function_annotation = Enzyme.Const,
3233 # ),
33- # ADTypes.AutoForwardDiff(),
3434 ]
3535 compute_modes = ContinuousNormalizingFlows. ComputeMode[
3636 ContinuousNormalizingFlows. LuxVecJacMatrixMode(ADTypes. AutoZygote()),
3737 ContinuousNormalizingFlows. DIVecJacVectorMode(ADTypes. AutoZygote()),
3838 ContinuousNormalizingFlows. DIVecJacMatrixMode(ADTypes. AutoZygote()),
39+ ContinuousNormalizingFlows. LuxJacVecMatrixMode(ADTypes. AutoForwardDiff()),
40+ ContinuousNormalizingFlows. DIJacVecVectorMode(ADTypes. AutoForwardDiff()),
41+ ContinuousNormalizingFlows. DIJacVecMatrixMode(ADTypes. AutoForwardDiff()),
3942 ContinuousNormalizingFlows. DIVecJacVectorMode(
4043 ADTypes. AutoEnzyme(;
4144 mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
4245 function_annotation = Enzyme. Const,
4346 ),
4447 ),
45- ContinuousNormalizingFlows. DIJacVecVectorMode (
48+ ContinuousNormalizingFlows. DIVecJacMatrixMode (
4649 ADTypes. AutoEnzyme(;
47- mode = Enzyme. set_runtime_activity(Enzyme. Forward ),
50+ mode = Enzyme. set_runtime_activity(Enzyme. Reverse ),
4851 function_annotation = Enzyme. Const,
4952 ),
5053 ),
51- ContinuousNormalizingFlows. DIVecJacMatrixMode (
54+ ContinuousNormalizingFlows. DIJacVecVectorMode (
5255 ADTypes. AutoEnzyme(;
53- mode = Enzyme. set_runtime_activity(Enzyme. Reverse ),
56+ mode = Enzyme. set_runtime_activity(Enzyme. Forward ),
5457 function_annotation = Enzyme. Const,
5558 ),
5659 ),
@@ -60,9 +63,6 @@ Test.@testset "Smoke Tests" begin
6063 function_annotation = Enzyme. Const,
6164 ),
6265 ),
63- ContinuousNormalizingFlows. LuxJacVecMatrixMode(ADTypes. AutoForwardDiff()),
64- ContinuousNormalizingFlows. DIJacVecVectorMode(ADTypes. AutoForwardDiff()),
65- ContinuousNormalizingFlows. DIJacVecMatrixMode(ADTypes. AutoForwardDiff()),
6666 ]
6767
6868 Test. @testset " $device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt " for device in
@@ -193,6 +193,11 @@ Test.@testset "Smoke Tests" begin
193193 Test. @test ! isnothing(rand(d))
194194 Test. @test ! isnothing(rand(d, ndata))
195195
196+ if GROUP != " All" &&
197+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode }
198+ continue
199+ end
200+
196201 Test. @testset " $adtype on loss" for adtype in adtypes
197202 Test. @test ! isnothing(DifferentiationInterface. gradient(diff_loss, adtype, ps))
198203 Test. @test ! isnothing(DifferentiationInterface. gradient(diff2_loss, adtype, r))
0 commit comments