diff --git a/test/checkby_JET_tests.jl b/test/checkby_JET_tests.jl index 6119889e..021bb0d9 100644 --- a/test/checkby_JET_tests.jl +++ b/test/checkby_JET_tests.jl @@ -116,6 +116,13 @@ Test.@testset "CheckByJET" begin ps = device(ps) st = device(st) + if GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + planar && + VERSION >= v"1.11" + continue + end + if cond ContinuousNormalizingFlows.loss(icnf, omode, r, r2, ps, st) JET.test_call( diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl index 99f6f708..377130b6 100644 --- a/test/smoke_tests.jl +++ b/test/smoke_tests.jl @@ -200,14 +200,25 @@ Test.@testset "Smoke Tests" begin Test.@test !isnothing(rand(d)) Test.@test !isnothing(rand(d, ndata)) - if GROUP != "All" && - compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} - continue - end - Test.@testset "$adtype on loss" for adtype in adtypes - Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) - Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) + Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) + Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) if cond model = ContinuousNormalizingFlows.CondICNFModel( @@ -218,14 +229,54 @@ Test.@testset "Smoke Tests" begin ) mach = MLJBase.machine(model, (df, df2)) - Test.@test !isnothing(MLJBase.fit!(mach)) - Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) - Test.@test !isnothing(MLJBase.fitted_params(mach)) - Test.@test !isnothing(MLJBase.serializable(mach)) + Test.@test !isnothing(MLJBase.fit!(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) + Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) + Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) + Test.@test !isnothing(MLJBase.serializable(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) Test.@test !isnothing( ContinuousNormalizingFlows.CondICNFDist(mach, omode, r2), - ) + ) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) else model = ContinuousNormalizingFlows.ICNFModel( icnf; @@ -235,12 +286,52 @@ Test.@testset "Smoke Tests" begin ) mach = MLJBase.machine(model, df) - Test.@test !isnothing(MLJBase.fit!(mach)) - Test.@test !isnothing(MLJBase.transform(mach, df)) - Test.@test !isnothing(MLJBase.fitted_params(mach)) - Test.@test !isnothing(MLJBase.serializable(mach)) + Test.@test !isnothing(MLJBase.fit!(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) + Test.@test !isnothing(MLJBase.transform(mach, df)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) + Test.@test !isnothing(MLJBase.fitted_params(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) + Test.@test !isnothing(MLJBase.serializable(mach)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) - Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) + Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) broken = + GROUP != "All" && + compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && + ( + omode isa ContinuousNormalizingFlows.TrainMode || ( + omode isa ContinuousNormalizingFlows.TestMode && + compute_mode isa ContinuousNormalizingFlows.VectorMode + ) + ) end end end