Skip to content

Commit c10496e

Browse files
committed
fix errors
1 parent ea4cf01 commit c10496e

File tree

6 files changed

+5
-26
lines changed

6 files changed

+5
-26
lines changed

examples/usage.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ model = ICNFModel(
5555
optimizers = (Adam(),),
5656
adtype = AutoZygote(),
5757
batchsize = 512,
58-
sol_kwargs = (; progress = true, epochs = 300), # pass to the solver
58+
sol_kwargs = (; epochs = 300), # pass to the solver
5959
)
6060
mach = machine(model, df)
6161
fit!(mach)

src/exts/mlj_ext/core.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,7 @@ function make_dataloader(
2222
::Int,
2323
data::Tuple,
2424
)
25-
return MLUtils.DataLoader(
26-
data;
27-
batchsize = -1,
28-
shuffle = true,
29-
partial = true,
30-
rng = icnf.rng,
31-
)
25+
return MLUtils.DataLoader(data; batchsize = -1, shuffle = true, partial = true)
3226
end
3327

3428
function make_dataloader(
@@ -45,6 +39,5 @@ function make_dataloader(
4539
end,
4640
shuffle = true,
4741
partial = true,
48-
rng = icnf.rng,
4942
)
5043
end

test/Project.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,13 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1010
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
1111
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1212
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
13-
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1413
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1514
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1615
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
1716
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
1817
OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7"
1918
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2019
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
21-
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
2220
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2321
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2422

@@ -34,15 +32,13 @@ Enzyme = "0.13"
3432
ExplicitImports = "1"
3533
ForwardDiff = "1"
3634
JET = "0.9, 0.10"
37-
Logging = "1"
3835
Lux = "1"
3936
LuxCore = "1"
4037
MLDataDevices = "1"
4138
MLJBase = "1"
4239
OrdinaryDiffEqDefault = "1"
4340
SciMLSensitivity = "7"
4441
StableRNGs = "1"
45-
TerminalLoggers = "0.1"
4642
Test = "1"
4743
Zygote = "0.7"
4844
julia = "1.10"

test/ci_tests/regression_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Regression Test
3535
model = ContinuousNormalizingFlows.ICNFModel(
3636
icnf;
3737
batchsize = 0,
38-
sol_kwargs = (; progress = true, epochs = 300),
38+
sol_kwargs = (; epochs = 300),
3939
)
4040

4141
mach = MLJBase.machine(model, df)

test/ci_tests/smoke_tests.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be
207207

208208
Test.@testset verbose = true showtiming = true failfast = false "$adtype on loss" for adtype in
209209
adtypes
210-
211210
Test.@test !isnothing(DifferentiationInterface.gradient(diff_loss, adtype, ps)) broken =
212211
compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} && (
213212
omode isa ContinuousNormalizingFlows.TrainMode || (
@@ -228,7 +227,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be
228227
icnf;
229228
adtype,
230229
batchsize = 0,
231-
sol_kwargs = (; progress = true, epochs = 2),
230+
sol_kwargs = (; epochs = 2),
232231
)
233232
mach = MLJBase.machine(model, (df, df2))
234233

@@ -249,7 +248,7 @@ Test.@testset verbose = true showtiming = true failfast = false "Smoke Tests" be
249248
icnf;
250249
adtype,
251250
batchsize = 0,
252-
sol_kwargs = (; progress = true, epochs = 2),
251+
sol_kwargs = (; epochs = 2),
253252
)
254253
mach = MLJBase.machine(model, df)
255254

test/runtests.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,19 @@ import ADTypes,
99
ExplicitImports,
1010
ForwardDiff,
1111
JET,
12-
Logging,
1312
Lux,
1413
LuxCore,
1514
MLDataDevices,
1615
MLJBase,
1716
OrdinaryDiffEqDefault,
1817
SciMLSensitivity,
1918
StableRNGs,
20-
TerminalLoggers,
2119
Test,
2220
Zygote,
2321
ContinuousNormalizingFlows
2422

2523
GROUP = get(ENV, "GROUP", "All")
2624

27-
if GROUP == "All"
28-
GC.enable_logging(true)
29-
30-
debuglogger = TerminalLoggers.TerminalLogger(stderr, Logging.Debug)
31-
Logging.global_logger(debuglogger)
32-
end
33-
3425
Test.@testset verbose = true showtiming = true failfast = false "Overall" begin
3526
if GROUP == "All" || GROUP in ["SmokeXOut", "SmokeXIn", "SmokeXYOut", "SmokeXYIn"]
3627
include(joinpath("ci_tests", "smoke_tests.jl"))

0 commit comments

Comments
 (0)