@@ -200,14 +200,25 @@ Test.@testset "Smoke Tests" begin
200200 Test. @test ! isnothing(rand(d))
201201 Test. @test ! isnothing(rand(d, ndata))
202202
203- if GROUP != " All" &&
204- compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode }
205- continue
206- end
207-
208203 Test. @testset " $adtype on loss" for adtype in adtypes
209- Test. @test ! isnothing(DifferentiationInterface. gradient(diff_loss, adtype, ps))
210- Test. @test ! isnothing(DifferentiationInterface. gradient(diff2_loss, adtype, r))
204+ Test. @test ! isnothing(DifferentiationInterface. gradient(diff_loss, adtype, ps)) broken =
205+ GROUP != " All" &&
206+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode } &&
207+ (
208+ omode isa ContinuousNormalizingFlows. TrainMode || (
209+ omode isa ContinuousNormalizingFlows. TestMode &&
210+ compute_mode isa ContinuousNormalizingFlows. VectorMode
211+ )
212+ )
213+ Test. @test ! isnothing(DifferentiationInterface. gradient(diff2_loss, adtype, r)) broken =
214+ GROUP != " All" &&
215+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode } &&
216+ (
217+ omode isa ContinuousNormalizingFlows. TrainMode || (
218+ omode isa ContinuousNormalizingFlows. TestMode &&
219+ compute_mode isa ContinuousNormalizingFlows. VectorMode
220+ )
221+ )
211222
212223 if cond
213224 model = ContinuousNormalizingFlows. CondICNFModel(
@@ -218,14 +229,54 @@ Test.@testset "Smoke Tests" begin
218229 )
219230 mach = MLJBase. machine(model, (df, df2))
220231
221- Test. @test ! isnothing(MLJBase. fit!(mach))
222- Test. @test ! isnothing(MLJBase. transform(mach, (df, df2)))
223- Test. @test ! isnothing(MLJBase. fitted_params(mach))
224- Test. @test ! isnothing(MLJBase. serializable(mach))
232+ Test. @test ! isnothing(MLJBase. fit!(mach)) broken =
233+ GROUP != " All" &&
234+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode } &&
235+ (
236+ omode isa ContinuousNormalizingFlows. TrainMode || (
237+ omode isa ContinuousNormalizingFlows. TestMode &&
238+ compute_mode isa ContinuousNormalizingFlows. VectorMode
239+ )
240+ )
241+ Test. @test ! isnothing(MLJBase. transform(mach, (df, df2))) broken =
242+ GROUP != " All" &&
243+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode } &&
244+ (
245+ omode isa ContinuousNormalizingFlows. TrainMode || (
246+ omode isa ContinuousNormalizingFlows. TestMode &&
247+ compute_mode isa ContinuousNormalizingFlows. VectorMode
248+ )
249+ )
250+ Test. @test ! isnothing(MLJBase. fitted_params(mach)) broken =
251+ GROUP != " All" &&
252+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode } &&
253+ (
254+ omode isa ContinuousNormalizingFlows. TrainMode || (
255+ omode isa ContinuousNormalizingFlows. TestMode &&
256+ compute_mode isa ContinuousNormalizingFlows. VectorMode
257+ )
258+ )
259+ Test. @test ! isnothing(MLJBase. serializable(mach)) broken =
260+ GROUP != " All" &&
261+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode } &&
262+ (
263+ omode isa ContinuousNormalizingFlows. TrainMode || (
264+ omode isa ContinuousNormalizingFlows. TestMode &&
265+ compute_mode isa ContinuousNormalizingFlows. VectorMode
266+ )
267+ )
225268
226269 Test. @test ! isnothing(
227270 ContinuousNormalizingFlows. CondICNFDist(mach, omode, r2),
228- )
271+ ) broken =
272+ GROUP != " All" &&
273+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode } &&
274+ (
275+ omode isa ContinuousNormalizingFlows. TrainMode || (
276+ omode isa ContinuousNormalizingFlows. TestMode &&
277+ compute_mode isa ContinuousNormalizingFlows. VectorMode
278+ )
279+ )
229280 else
230281 model = ContinuousNormalizingFlows. ICNFModel(
231282 icnf;
@@ -235,12 +286,52 @@ Test.@testset "Smoke Tests" begin
235286 )
236287 mach = MLJBase. machine(model, df)
237288
238- Test. @test ! isnothing(MLJBase. fit!(mach))
239- Test. @test ! isnothing(MLJBase. transform(mach, df))
240- Test. @test ! isnothing(MLJBase. fitted_params(mach))
241- Test. @test ! isnothing(MLJBase. serializable(mach))
289+ Test. @test ! isnothing(MLJBase. fit!(mach)) broken =
290+ GROUP != " All" &&
291+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode } &&
292+ (
293+ omode isa ContinuousNormalizingFlows. TrainMode || (
294+ omode isa ContinuousNormalizingFlows. TestMode &&
295+ compute_mode isa ContinuousNormalizingFlows. VectorMode
296+ )
297+ )
298+ Test. @test ! isnothing(MLJBase. transform(mach, df)) broken =
299+ GROUP != " All" &&
300+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode } &&
301+ (
302+ omode isa ContinuousNormalizingFlows. TrainMode || (
303+ omode isa ContinuousNormalizingFlows. TestMode &&
304+ compute_mode isa ContinuousNormalizingFlows. VectorMode
305+ )
306+ )
307+ Test. @test ! isnothing(MLJBase. fitted_params(mach)) broken =
308+ GROUP != " All" &&
309+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode } &&
310+ (
311+ omode isa ContinuousNormalizingFlows. TrainMode || (
312+ omode isa ContinuousNormalizingFlows. TestMode &&
313+ compute_mode isa ContinuousNormalizingFlows. VectorMode
314+ )
315+ )
316+ Test. @test ! isnothing(MLJBase. serializable(mach)) broken =
317+ GROUP != " All" &&
318+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode } &&
319+ (
320+ omode isa ContinuousNormalizingFlows. TrainMode || (
321+ omode isa ContinuousNormalizingFlows. TestMode &&
322+ compute_mode isa ContinuousNormalizingFlows. VectorMode
323+ )
324+ )
242325
243- Test. @test ! isnothing(ContinuousNormalizingFlows. ICNFDist(mach, omode))
326+ Test. @test ! isnothing(ContinuousNormalizingFlows. ICNFDist(mach, omode)) broken =
327+ GROUP != " All" &&
328+ compute_mode. adback isa ADTypes. AutoEnzyme{<: Enzyme.ForwardMode } &&
329+ (
330+ omode isa ContinuousNormalizingFlows. TrainMode || (
331+ omode isa ContinuousNormalizingFlows. TestMode &&
332+ compute_mode isa ContinuousNormalizingFlows. VectorMode
333+ )
334+ )
244335 end
245336 end
246337 end
0 commit comments