Skip to content

Commit d3445cd

Browse files
authored
test more enzyme (#491)
* test more enzyme * mark them broken * fix maybe
1 parent a27c531 commit d3445cd

File tree

2 files changed

+115
-17
lines changed

2 files changed

+115
-17
lines changed

test/checkby_JET_tests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,13 @@ Test.@testset "CheckByJET" begin
116116
ps = device(ps)
117117
st = device(st)
118118

119+
if GROUP != "All" &&
120+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
121+
planar &&
122+
VERSION >= v"1.11"
123+
continue
124+
end
125+
119126
if cond
120127
ContinuousNormalizingFlows.loss(icnf, omode, r, r2, ps, st)
121128
JET.test_call(

test/smoke_tests.jl

Lines changed: 108 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -200,14 +200,25 @@ Test.@testset "Smoke Tests" begin
200200
Test.@test !isnothing(rand(d))
201201
Test.@test !isnothing(rand(d, ndata))
202202

203-
if GROUP != "All" &&
204-
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode}
205-
continue
206-
end
207-
208203
Test.@testset "$adtype on loss" for adtype in adtypes
209-
Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps))
210-
Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r))
204+
Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken =
205+
GROUP != "All" &&
206+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
207+
(
208+
omode isa ContinuousNormalizingFlows.TrainMode || (
209+
omode isa ContinuousNormalizingFlows.TestMode &&
210+
compute_mode isa ContinuousNormalizingFlows.VectorMode
211+
)
212+
)
213+
Test.@test !isnothing(DifferentiationInterface.gradient(diff2_loss, adtype, r)) broken =
214+
GROUP != "All" &&
215+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
216+
(
217+
omode isa ContinuousNormalizingFlows.TrainMode || (
218+
omode isa ContinuousNormalizingFlows.TestMode &&
219+
compute_mode isa ContinuousNormalizingFlows.VectorMode
220+
)
221+
)
211222

212223
if cond
213224
model = ContinuousNormalizingFlows.CondICNFModel(
@@ -218,14 +229,54 @@ Test.@testset "Smoke Tests" begin
218229
)
219230
mach = MLJBase.machine(model, (df, df2))
220231

221-
Test.@test !isnothing(MLJBase.fit!(mach))
222-
Test.@test !isnothing(MLJBase.transform(mach, (df, df2)))
223-
Test.@test !isnothing(MLJBase.fitted_params(mach))
224-
Test.@test !isnothing(MLJBase.serializable(mach))
232+
Test.@test !isnothing(MLJBase.fit!(mach)) broken =
233+
GROUP != "All" &&
234+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
235+
(
236+
omode isa ContinuousNormalizingFlows.TrainMode || (
237+
omode isa ContinuousNormalizingFlows.TestMode &&
238+
compute_mode isa ContinuousNormalizingFlows.VectorMode
239+
)
240+
)
241+
Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) broken =
242+
GROUP != "All" &&
243+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
244+
(
245+
omode isa ContinuousNormalizingFlows.TrainMode || (
246+
omode isa ContinuousNormalizingFlows.TestMode &&
247+
compute_mode isa ContinuousNormalizingFlows.VectorMode
248+
)
249+
)
250+
Test.@test !isnothing(MLJBase.fitted_params(mach)) broken =
251+
GROUP != "All" &&
252+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
253+
(
254+
omode isa ContinuousNormalizingFlows.TrainMode || (
255+
omode isa ContinuousNormalizingFlows.TestMode &&
256+
compute_mode isa ContinuousNormalizingFlows.VectorMode
257+
)
258+
)
259+
Test.@test !isnothing(MLJBase.serializable(mach)) broken =
260+
GROUP != "All" &&
261+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
262+
(
263+
omode isa ContinuousNormalizingFlows.TrainMode || (
264+
omode isa ContinuousNormalizingFlows.TestMode &&
265+
compute_mode isa ContinuousNormalizingFlows.VectorMode
266+
)
267+
)
225268

226269
Test.@test !isnothing(
227270
ContinuousNormalizingFlows.CondICNFDist(mach, omode, r2),
228-
)
271+
) broken =
272+
GROUP != "All" &&
273+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
274+
(
275+
omode isa ContinuousNormalizingFlows.TrainMode || (
276+
omode isa ContinuousNormalizingFlows.TestMode &&
277+
compute_mode isa ContinuousNormalizingFlows.VectorMode
278+
)
279+
)
229280
else
230281
model = ContinuousNormalizingFlows.ICNFModel(
231282
icnf;
@@ -235,12 +286,52 @@ Test.@testset "Smoke Tests" begin
235286
)
236287
mach = MLJBase.machine(model, df)
237288

238-
Test.@test !isnothing(MLJBase.fit!(mach))
239-
Test.@test !isnothing(MLJBase.transform(mach, df))
240-
Test.@test !isnothing(MLJBase.fitted_params(mach))
241-
Test.@test !isnothing(MLJBase.serializable(mach))
289+
Test.@test !isnothing(MLJBase.fit!(mach)) broken =
290+
GROUP != "All" &&
291+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
292+
(
293+
omode isa ContinuousNormalizingFlows.TrainMode || (
294+
omode isa ContinuousNormalizingFlows.TestMode &&
295+
compute_mode isa ContinuousNormalizingFlows.VectorMode
296+
)
297+
)
298+
Test.@test !isnothing(MLJBase.transform(mach, df)) broken =
299+
GROUP != "All" &&
300+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
301+
(
302+
omode isa ContinuousNormalizingFlows.TrainMode || (
303+
omode isa ContinuousNormalizingFlows.TestMode &&
304+
compute_mode isa ContinuousNormalizingFlows.VectorMode
305+
)
306+
)
307+
Test.@test !isnothing(MLJBase.fitted_params(mach)) broken =
308+
GROUP != "All" &&
309+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
310+
(
311+
omode isa ContinuousNormalizingFlows.TrainMode || (
312+
omode isa ContinuousNormalizingFlows.TestMode &&
313+
compute_mode isa ContinuousNormalizingFlows.VectorMode
314+
)
315+
)
316+
Test.@test !isnothing(MLJBase.serializable(mach)) broken =
317+
GROUP != "All" &&
318+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
319+
(
320+
omode isa ContinuousNormalizingFlows.TrainMode || (
321+
omode isa ContinuousNormalizingFlows.TestMode &&
322+
compute_mode isa ContinuousNormalizingFlows.VectorMode
323+
)
324+
)
242325

243-
Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode))
326+
Test.@test !isnothing(ContinuousNormalizingFlows.ICNFDist(mach, omode)) broken =
327+
GROUP != "All" &&
328+
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
329+
(
330+
omode isa ContinuousNormalizingFlows.TrainMode || (
331+
omode isa ContinuousNormalizingFlows.TestMode &&
332+
compute_mode isa ContinuousNormalizingFlows.VectorMode
333+
)
334+
)
244335
end
245336
end
246337
end

0 commit comments

Comments
 (0)