@@ -5,43 +5,143 @@ Test.@testset "CheckByJET" begin
55 mode = :typo,
66 )
77
8- nvars = 2 ^ 3
9- naugs = nvars
10- n_in = nvars + naugs
11- n = 2 ^ 6
12- nn = Lux. Chain(Lux. Dense(n_in => n_in, tanh))
13-
14- icnf = ContinuousNormalizingFlows. construct(
15- ContinuousNormalizingFlows. ICNF,
16- nn,
17- nvars,
18- naugs;
19- compute_mode = ContinuousNormalizingFlows. LuxVecJacMatrixMode(ADTypes. AutoZygote()),
20- tspan = (0.0f0 , 1.0f0 ),
21- steer_rate = 1.0f-1 ,
22- λ₁ = 1.0f-2 ,
23- λ₂ = 1.0f-2 ,
24- λ₃ = 1.0f-2 ,
25- sol_kwargs = (;
26- save_everystep = false ,
27- alg = OrdinaryDiffEqDefault. DefaultODEAlgorithm(),
28- sensealg = SciMLSensitivity. InterpolatingAdjoint(),
8+ mts = Type{<: ContinuousNormalizingFlows.AbstractICNF }[ContinuousNormalizingFlows. ICNF]
9+ omodes = ContinuousNormalizingFlows. Mode[
10+ ContinuousNormalizingFlows. TrainMode(),
11+ ContinuousNormalizingFlows. TestMode(),
12+ ]
13+ conds = Bool[false , true ]
14+ inplaces = Bool[false , true ]
15+ planars = Bool[false , true ]
16+ nvars_ = Int[2 ]
17+ ndata_ = Int[4 ]
18+ data_types = Type{<: AbstractFloat }[Float32]
19+ devices = MLDataDevices. AbstractDevice[MLDataDevices. cpu_device()]
20+ compute_modes = ContinuousNormalizingFlows. ComputeMode[
21+ ContinuousNormalizingFlows. LuxVecJacMatrixMode(ADTypes. AutoZygote()),
22+ ContinuousNormalizingFlows. DIVecJacVectorMode(ADTypes. AutoZygote()),
23+ ContinuousNormalizingFlows. DIVecJacMatrixMode(ADTypes. AutoZygote()),
24+ ContinuousNormalizingFlows. LuxJacVecMatrixMode(ADTypes. AutoForwardDiff()),
25+ ContinuousNormalizingFlows. DIJacVecVectorMode(ADTypes. AutoForwardDiff()),
26+ ContinuousNormalizingFlows. DIJacVecMatrixMode(ADTypes. AutoForwardDiff()),
27+ ContinuousNormalizingFlows. DIVecJacVectorMode(
28+ ADTypes. AutoEnzyme(;
29+ mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
30+ function_annotation = Enzyme. Const,
31+ ),
2932 ),
30- )
31- ps, st = LuxCore. setup(icnf. rng, icnf)
32- ps = ComponentArrays. ComponentArray(ps)
33- r = rand(icnf. rng, Float32, nvars, n)
33+ ContinuousNormalizingFlows. DIVecJacMatrixMode(
34+ ADTypes. AutoEnzyme(;
35+ mode = Enzyme. set_runtime_activity(Enzyme. Reverse),
36+ function_annotation = Enzyme. Const,
37+ ),
38+ ),
39+ ContinuousNormalizingFlows. DIJacVecVectorMode(
40+ ADTypes. AutoEnzyme(;
41+ mode = Enzyme. set_runtime_activity(Enzyme. Forward),
42+ function_annotation = Enzyme. Const,
43+ ),
44+ ),
45+ ContinuousNormalizingFlows. DIJacVecMatrixMode(
46+ ADTypes. AutoEnzyme(;
47+ mode = Enzyme. set_runtime_activity(Enzyme. Forward),
48+ function_annotation = Enzyme. Const,
49+ ),
50+ ),
51+ ]
3452
35- ContinuousNormalizingFlows. loss(icnf, ContinuousNormalizingFlows. TrainMode(), r, ps, st)
36- JET. test_call(
37- ContinuousNormalizingFlows. loss,
38- Base. typesof(icnf, ContinuousNormalizingFlows. TrainMode(), r, ps, st);
39- target_modules = [ContinuousNormalizingFlows],
40- mode = :typo,
41- )
42- JET. test_opt(
43- ContinuousNormalizingFlows. loss,
44- Base. typesof(icnf, ContinuousNormalizingFlows. TrainMode(), r, ps, st);
45- target_modules = [ContinuousNormalizingFlows],
46- )
53+ Test. @testset " $device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt " for device in
54+ devices,
55+ data_type in data_types,
56+ compute_mode in compute_modes,
57+ ndata in ndata_,
58+ nvars in nvars_,
59+ inplace in inplaces,
60+ cond in conds,
61+ planar in planars,
62+ omode in omodes,
63+ mt in mts
64+
65+ data_dist =
66+ Distributions. Beta{data_type}(convert(Tuple{data_type, data_type}, (2 , 4 )). .. )
67+ data_dist2 =
68+ Distributions. Beta{data_type}(convert(Tuple{data_type, data_type}, (4 , 2 )). .. )
69+ if compute_mode isa ContinuousNormalizingFlows. VectorMode
70+ r = convert.(data_type, rand(data_dist, nvars))
71+ r2 = convert.(data_type, rand(data_dist2, nvars))
72+ elseif compute_mode isa ContinuousNormalizingFlows. MatrixMode
73+ r = convert.(data_type, rand(data_dist, nvars, ndata))
74+ r2 = convert.(data_type, rand(data_dist2, nvars, ndata))
75+ end
76+
77+ nn = ifelse(
78+ cond,
79+ ifelse(
80+ planar,
81+ Lux. Chain(
82+ ContinuousNormalizingFlows. PlanarLayer(nvars * 2 , tanh; n_cond = nvars),
83+ ),
84+ Lux. Chain(Lux. Dense(nvars * 3 => nvars * 2 , tanh)),
85+ ),
86+ ifelse(
87+ planar,
88+ Lux. Chain(ContinuousNormalizingFlows. PlanarLayer(nvars * 2 , tanh)),
89+ Lux. Chain(Lux. Dense(nvars * 2 => nvars * 2 , tanh)),
90+ ),
91+ )
92+ icnf = ContinuousNormalizingFlows. construct(
93+ mt,
94+ nn,
95+ nvars,
96+ nvars;
97+ data_type,
98+ compute_mode,
99+ inplace,
100+ cond,
101+ device,
102+ steer_rate = convert(data_type, 1.0e-1 ),
103+ λ₁ = convert(data_type, 1.0e-2 ),
104+ λ₂ = convert(data_type, 1.0e-2 ),
105+ λ₃ = convert(data_type, 1.0e-2 ),
106+ sol_kwargs = (;
107+ save_everystep = false ,
108+ alg = OrdinaryDiffEqDefault. DefaultODEAlgorithm(),
109+ sensealg = SciMLSensitivity. InterpolatingAdjoint(),
110+ ),
111+ )
112+ ps, st = LuxCore. setup(icnf. rng, icnf)
113+ ps = ComponentArrays. ComponentArray(ps)
114+ r = device(r)
115+ r2 = device(r2)
116+ ps = device(ps)
117+ st = device(st)
118+
119+ if cond
120+ ContinuousNormalizingFlows. loss(icnf, omode, r, r2, ps, st)
121+ JET. test_call(
122+ ContinuousNormalizingFlows. loss,
123+ Base. typesof(icnf, omode, r, r2, ps, st);
124+ target_modules = [ContinuousNormalizingFlows],
125+ mode = :typo,
126+ )
127+ JET. test_opt(
128+ ContinuousNormalizingFlows. loss,
129+ Base. typesof(icnf, omode, r, r2, ps, st);
130+ target_modules = [ContinuousNormalizingFlows],
131+ )
132+ else
133+ ContinuousNormalizingFlows. loss(icnf, omode, r, ps, st)
134+ JET. test_call(
135+ ContinuousNormalizingFlows. loss,
136+ Base. typesof(icnf, omode, r, ps, st);
137+ target_modules = [ContinuousNormalizingFlows],
138+ mode = :typo,
139+ )
140+ JET. test_opt(
141+ ContinuousNormalizingFlows. loss,
142+ Base. typesof(icnf, omode, r, ps, st);
143+ target_modules = [ContinuousNormalizingFlows],
144+ )
145+ end
146+ end
47147end
0 commit comments