Skip to content

Commit b970935

Browse files
committed
update JET tests
1 parent ce0f413 commit b970935

File tree

2 files changed

+138
-38
lines changed

2 files changed

+138
-38
lines changed

.github/workflows/CI-CheckBy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
version:
2727
- release
2828
- lts
29-
- nightly
29+
# - nightly
3030
os:
3131
- ubuntu-latest
3232
# - macOS-latest

test/checkby_JET_tests.jl

Lines changed: 137 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,143 @@ Test.@testset "CheckByJET" begin
55
mode = :typo,
66
)
77

8-
nvars = 2^3
9-
naugs = nvars
10-
n_in = nvars + naugs
11-
n = 2^6
12-
nn = Lux.Chain(Lux.Dense(n_in => n_in, tanh))
13-
14-
icnf = ContinuousNormalizingFlows.construct(
15-
ContinuousNormalizingFlows.ICNF,
16-
nn,
17-
nvars,
18-
naugs;
19-
compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
20-
tspan = (0.0f0, 1.0f0),
21-
steer_rate = 1.0f-1,
22-
λ₁ = 1.0f-2,
23-
λ₂ = 1.0f-2,
24-
λ₃ = 1.0f-2,
25-
sol_kwargs = (;
26-
save_everystep = false,
27-
alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
28-
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
8+
mts = Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.ICNF]
9+
omodes = ContinuousNormalizingFlows.Mode[
10+
ContinuousNormalizingFlows.TrainMode(),
11+
ContinuousNormalizingFlows.TestMode(),
12+
]
13+
conds = Bool[false, true]
14+
inplaces = Bool[false, true]
15+
planars = Bool[false, true]
16+
nvars_ = Int[2]
17+
ndata_ = Int[4]
18+
data_types = Type{<:AbstractFloat}[Float32]
19+
devices = MLDataDevices.AbstractDevice[MLDataDevices.cpu_device()]
20+
compute_modes = ContinuousNormalizingFlows.ComputeMode[
21+
ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
22+
ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
23+
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
24+
ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
25+
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoForwardDiff()),
26+
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
27+
ContinuousNormalizingFlows.DIVecJacVectorMode(
28+
ADTypes.AutoEnzyme(;
29+
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
30+
function_annotation = Enzyme.Const,
31+
),
2932
),
30-
)
31-
ps, st = LuxCore.setup(icnf.rng, icnf)
32-
ps = ComponentArrays.ComponentArray(ps)
33-
r = rand(icnf.rng, Float32, nvars, n)
33+
ContinuousNormalizingFlows.DIVecJacMatrixMode(
34+
ADTypes.AutoEnzyme(;
35+
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
36+
function_annotation = Enzyme.Const,
37+
),
38+
),
39+
ContinuousNormalizingFlows.DIJacVecVectorMode(
40+
ADTypes.AutoEnzyme(;
41+
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
42+
function_annotation = Enzyme.Const,
43+
),
44+
),
45+
ContinuousNormalizingFlows.DIJacVecMatrixMode(
46+
ADTypes.AutoEnzyme(;
47+
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
48+
function_annotation = Enzyme.Const,
49+
),
50+
),
51+
]
3452

35-
ContinuousNormalizingFlows.loss(icnf, ContinuousNormalizingFlows.TrainMode(), r, ps, st)
36-
JET.test_call(
37-
ContinuousNormalizingFlows.loss,
38-
Base.typesof(icnf, ContinuousNormalizingFlows.TrainMode(), r, ps, st);
39-
target_modules = [ContinuousNormalizingFlows],
40-
mode = :typo,
41-
)
42-
JET.test_opt(
43-
ContinuousNormalizingFlows.loss,
44-
Base.typesof(icnf, ContinuousNormalizingFlows.TrainMode(), r, ps, st);
45-
target_modules = [ContinuousNormalizingFlows],
46-
)
53+
Test.@testset "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in
54+
devices,
55+
data_type in data_types,
56+
compute_mode in compute_modes,
57+
ndata in ndata_,
58+
nvars in nvars_,
59+
inplace in inplaces,
60+
cond in conds,
61+
planar in planars,
62+
omode in omodes,
63+
mt in mts
64+
65+
data_dist =
66+
Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (2, 4))...)
67+
data_dist2 =
68+
Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (4, 2))...)
69+
if compute_mode isa ContinuousNormalizingFlows.VectorMode
70+
r = convert.(data_type, rand(data_dist, nvars))
71+
r2 = convert.(data_type, rand(data_dist2, nvars))
72+
elseif compute_mode isa ContinuousNormalizingFlows.MatrixMode
73+
r = convert.(data_type, rand(data_dist, nvars, ndata))
74+
r2 = convert.(data_type, rand(data_dist2, nvars, ndata))
75+
end
76+
77+
nn = ifelse(
78+
cond,
79+
ifelse(
80+
planar,
81+
Lux.Chain(
82+
ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh; n_cond = nvars),
83+
),
84+
Lux.Chain(Lux.Dense(nvars * 3 => nvars * 2, tanh)),
85+
),
86+
ifelse(
87+
planar,
88+
Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh)),
89+
Lux.Chain(Lux.Dense(nvars * 2 => nvars * 2, tanh)),
90+
),
91+
)
92+
icnf = ContinuousNormalizingFlows.construct(
93+
mt,
94+
nn,
95+
nvars,
96+
nvars;
97+
data_type,
98+
compute_mode,
99+
inplace,
100+
cond,
101+
device,
102+
steer_rate = convert(data_type, 1.0e-1),
103+
λ₁ = convert(data_type, 1.0e-2),
104+
λ₂ = convert(data_type, 1.0e-2),
105+
λ₃ = convert(data_type, 1.0e-2),
106+
sol_kwargs = (;
107+
save_everystep = false,
108+
alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
109+
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
110+
),
111+
)
112+
ps, st = LuxCore.setup(icnf.rng, icnf)
113+
ps = ComponentArrays.ComponentArray(ps)
114+
r = device(r)
115+
r2 = device(r2)
116+
ps = device(ps)
117+
st = device(st)
118+
119+
if cond
120+
ContinuousNormalizingFlows.loss(icnf, omode, r, r2, ps, st)
121+
JET.test_call(
122+
ContinuousNormalizingFlows.loss,
123+
Base.typesof(icnf, omode, r, r2, ps, st);
124+
target_modules = [ContinuousNormalizingFlows],
125+
mode = :typo,
126+
)
127+
JET.test_opt(
128+
ContinuousNormalizingFlows.loss,
129+
Base.typesof(icnf, omode, r, r2, ps, st);
130+
target_modules = [ContinuousNormalizingFlows],
131+
)
132+
else
133+
ContinuousNormalizingFlows.loss(icnf, omode, r, ps, st)
134+
JET.test_call(
135+
ContinuousNormalizingFlows.loss,
136+
Base.typesof(icnf, omode, r, ps, st);
137+
target_modules = [ContinuousNormalizingFlows],
138+
mode = :typo,
139+
)
140+
JET.test_opt(
141+
ContinuousNormalizingFlows.loss,
142+
Base.typesof(icnf, omode, r, ps, st);
143+
target_modules = [ContinuousNormalizingFlows],
144+
)
145+
end
146+
end
47147
end

0 commit comments

Comments
 (0)