Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions test/checkby_JET_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ Test.@testset "CheckByJET" begin
ps = device(ps)
st = device(st)

if GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
planar &&
VERSION >= v"1.11"
continue
end

if cond
ContinuousNormalizingFlows.loss(icnf, omode, r, r2, ps, st)
JET.test_call(
Expand Down
125 changes: 108 additions & 17 deletions test/smoke_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,25 @@ Test.@testset "Smoke Tests" begin
Test.@test !isnothing(rand(d))
Test.@test !isnothing(rand(d, ndata))

if GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
continue
end

Test.@testset "$adtype on loss" for adtype in adtypes
Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps))
Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r))
Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken =
GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
(
omode isa ContinuousNormalizingFlows.TrainMode || (
omode isa ContinuousNormalizingFlows.TestMode &&
compute_mode isa ContinuousNormalizingFlows.VectorMode
)
)
Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken =
GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
(
omode isa ContinuousNormalizingFlows.TrainMode || (
omode isa ContinuousNormalizingFlows.TestMode &&
compute_mode isa ContinuousNormalizingFlows.VectorMode
)
)

if cond
model = ContinuousNormalizingFlows.CondICNFModel(
Expand All @@ -218,14 +229,54 @@ Test.@testset "Smoke Tests" begin
)
mach = MLJBase.machine(model, (df, df2))

Test.@test !isnothing(MLJBase.fit!(mach))
Test.@test !isnothing(MLJBase.transform(mach, (df, df2)))
Test.@test !isnothing(MLJBase.fitted_params(mach))
Test.@test !isnothing(MLJBase.serializable(mach))
Test.@test !isnothing(MLJBase.fit!(mach)) broken =
GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
(
omode isa ContinuousNormalizingFlows.TrainMode || (
omode isa ContinuousNormalizingFlows.TestMode &&
compute_mode isa ContinuousNormalizingFlows.VectorMode
)
)
Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) broken =
GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
(
omode isa ContinuousNormalizingFlows.TrainMode || (
omode isa ContinuousNormalizingFlows.TestMode &&
compute_mode isa ContinuousNormalizingFlows.VectorMode
)
)
Test.@test !isnothing(MLJBase.fitted_params(mach)) broken =
GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
(
omode isa ContinuousNormalizingFlows.TrainMode || (
omode isa ContinuousNormalizingFlows.TestMode &&
compute_mode isa ContinuousNormalizingFlows.VectorMode
)
)
Test.@test !isnothing(MLJBase.serializable(mach)) broken =
GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
(
omode isa ContinuousNormalizingFlows.TrainMode || (
omode isa ContinuousNormalizingFlows.TestMode &&
compute_mode isa ContinuousNormalizingFlows.VectorMode
)
)

Test.@test !isnothing(
ContinuousNormalizingFlows.CondICNFDist(mach, omode, r2),
)
) broken =
GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
(
omode isa ContinuousNormalizingFlows.TrainMode || (
omode isa ContinuousNormalizingFlows.TestMode &&
compute_mode isa ContinuousNormalizingFlows.VectorMode
)
)
else
model = ContinuousNormalizingFlows.ICNFModel(
icnf;
Expand All @@ -235,12 +286,52 @@ Test.@testset "Smoke Tests" begin
)
mach = MLJBase.machine(model, df)

Test.@test !isnothing(MLJBase.fit!(mach))
Test.@test !isnothing(MLJBase.transform(mach, df))
Test.@test !isnothing(MLJBase.fitted_params(mach))
Test.@test !isnothing(MLJBase.serializable(mach))
Test.@test !isnothing(MLJBase.fit!(mach)) broken =
GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
(
omode isa ContinuousNormalizingFlows.TrainMode || (
omode isa ContinuousNormalizingFlows.TestMode &&
compute_mode isa ContinuousNormalizingFlows.VectorMode
)
)
Test.@test !isnothing(MLJBase.transform(mach, df)) broken =
GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
(
omode isa ContinuousNormalizingFlows.TrainMode || (
omode isa ContinuousNormalizingFlows.TestMode &&
compute_mode isa ContinuousNormalizingFlows.VectorMode
)
)
Test.@test !isnothing(MLJBase.fitted_params(mach)) broken =
GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
(
omode isa ContinuousNormalizingFlows.TrainMode || (
omode isa ContinuousNormalizingFlows.TestMode &&
compute_mode isa ContinuousNormalizingFlows.VectorMode
)
)
Test.@test !isnothing(MLJBase.serializable(mach)) broken =
GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
(
omode isa ContinuousNormalizingFlows.TrainMode || (
omode isa ContinuousNormalizingFlows.TestMode &&
compute_mode isa ContinuousNormalizingFlows.VectorMode
)
)

Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode))
Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) broken =
GROUP != "All" &&
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
(
omode isa ContinuousNormalizingFlows.TrainMode || (
omode isa ContinuousNormalizingFlows.TestMode &&
compute_mode isa ContinuousNormalizingFlows.VectorMode
)
)
end
end
end
Expand Down
Loading