diff --git a/Project.toml b/Project.toml
index f4842780..4772d056 100644
--- a/Project.toml
+++ b/Project.toml
@@ -29,6 +29,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
 ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
@@ -57,5 +58,6 @@ SciMLBase = "2"
 SciMLSensitivity = "7"
 ScientificTypesBase = "3"
 Statistics = "1"
+WeightInitializers = "1"
 Zygote = "0.7"
 julia = "1.10"
diff --git a/benchmark/Project.toml b/benchmark/Project.toml
index 3b4d7455..e0ad0695 100644
--- a/benchmark/Project.toml
+++ b/benchmark/Project.toml
@@ -3,6 +3,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7"
@@ -16,6 +17,7 @@ ADTypes = "1"
 BenchmarkTools = "1"
 ComponentArrays = "0.15"
 DifferentiationInterface = "0.7"
+Distributions = "0.25"
 Lux = "1"
 LuxCore = "1"
 OrdinaryDiffEqDefault = "1"
diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl
index 254017fc..e5368cf0 100644
--- a/benchmark/benchmarks.jl
+++ b/benchmark/benchmarks.jl
@@ -2,6 +2,7 @@ import ADTypes,
     BenchmarkTools,
     ComponentArrays,
     DifferentiationInterface,
+    Distributions,
     Lux,
     LuxCore,
     OrdinaryDiffEqDefault,
@@ -11,25 +12,18 @@ import ADTypes,
     Zygote,
     ContinuousNormalizingFlows
 
-SUITE = BenchmarkTools.BenchmarkGroup()
-
-SUITE["main"] = BenchmarkTools.BenchmarkGroup(["package", "simple"])
-
-SUITE["main"]["no_inplace"] = BenchmarkTools.BenchmarkGroup(["no_inplace"])
-SUITE["main"]["inplace"] = BenchmarkTools.BenchmarkGroup(["inplace"])
-
-SUITE["main"]["no_inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
-SUITE["main"]["no_inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])
-
-SUITE["main"]["inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
-SUITE["main"]["inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])
-
 rng = StableRNGs.StableRNG(1)
-nvars = 2^3
+ndata = 2^10
+ndimension = 1
+data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0)
+r = rand(rng, data_dist, ndimension, ndata)
+r = convert.(Float32, r)
+
+nvars = size(r, 1)
 naugs = nvars
 n_in = nvars + naugs
-n = 2^6
-nn = Lux.Chain(Lux.Dense(n_in => n_in, tanh))
+
+nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh))
 
 icnf = ContinuousNormalizingFlows.construct(
     ContinuousNormalizingFlows.ICNF,
@@ -49,9 +43,29 @@ icnf = ContinuousNormalizingFlows.construct(
         sensealg = SciMLSensitivity.InterpolatingAdjoint(),
     ),
 )
+
+icnf2 = ContinuousNormalizingFlows.construct(
+    ContinuousNormalizingFlows.ICNF,
+    nn,
+    nvars,
+    naugs;
+    inplace = true,
+    compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
+    tspan = (0.0f0, 1.0f0),
+    steer_rate = 1.0f-1,
+    λ₁ = 1.0f-2,
+    λ₂ = 1.0f-2,
+    λ₃ = 1.0f-2,
+    rng,
+    sol_kwargs = (;
+        save_everystep = false,
+        alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
+        sensealg = SciMLSensitivity.InterpolatingAdjoint(),
+    ),
+)
+
 ps, st = LuxCore.setup(icnf.rng, icnf)
 ps = ComponentArrays.ComponentArray(ps)
-r = rand(icnf.rng, Float32, nvars, n)
 
 function diff_loss_tn(x::Any)
     return ContinuousNormalizingFlows.loss(
@@ -72,49 +86,6 @@ function diff_loss_tt(x::Any)
     )
 end
 
-diff_loss_tn(ps)
-diff_loss_tt(ps)
-DifferentiationInterface.gradient(diff_loss_tn, ADTypes.AutoZygote(), ps)
-DifferentiationInterface.gradient(diff_loss_tt, ADTypes.AutoZygote(), ps)
-GC.gc()
-
-SUITE["main"]["no_inplace"]["direct"]["train"] =
-    BenchmarkTools.@benchmarkable diff_loss_tn(ps)
-SUITE["main"]["no_inplace"]["direct"]["test"] =
-    BenchmarkTools.@benchmarkable diff_loss_tt(ps)
-SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
-    BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
-        diff_loss_tn,
-        ADTypes.AutoZygote(),
-        ps,
-    )
-SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
-    BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
-        diff_loss_tt,
-        ADTypes.AutoZygote(),
-        ps,
-    )
-
-icnf2 = ContinuousNormalizingFlows.construct(
-    ContinuousNormalizingFlows.ICNF,
-    nn,
-    nvars,
-    naugs;
-    inplace = true,
-    compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
-    tspan = (0.0f0, 1.0f0),
-    steer_rate = 1.0f-1,
-    λ₁ = 1.0f-2,
-    λ₂ = 1.0f-2,
-    λ₃ = 1.0f-2,
-    rng,
-    sol_kwargs = (;
-        save_everystep = false,
-        alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
-        sensealg = SciMLSensitivity.InterpolatingAdjoint(),
-    ),
-)
-
 function diff_loss_tn2(x::Any)
     return ContinuousNormalizingFlows.loss(
         icnf2,
@@ -134,12 +105,47 @@ function diff_loss_tt2(x::Any)
     )
 end
 
+diff_loss_tn(ps)
+diff_loss_tt(ps)
+DifferentiationInterface.gradient(diff_loss_tn, ADTypes.AutoZygote(), ps)
+DifferentiationInterface.gradient(diff_loss_tt, ADTypes.AutoZygote(), ps)
+
 diff_loss_tn2(ps)
 diff_loss_tt2(ps)
 DifferentiationInterface.gradient(diff_loss_tn2, ADTypes.AutoZygote(), ps)
 DifferentiationInterface.gradient(diff_loss_tt2, ADTypes.AutoZygote(), ps)
+
 GC.gc()
 
+SUITE = BenchmarkTools.BenchmarkGroup()
+
+SUITE["main"] = BenchmarkTools.BenchmarkGroup(["package", "simple"])
+
+SUITE["main"]["no_inplace"] = BenchmarkTools.BenchmarkGroup(["no_inplace"])
+SUITE["main"]["inplace"] = BenchmarkTools.BenchmarkGroup(["inplace"])
+
+SUITE["main"]["no_inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
+SUITE["main"]["no_inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])
+
+SUITE["main"]["inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
+SUITE["main"]["inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])
+
+SUITE["main"]["no_inplace"]["direct"]["train"] =
+    BenchmarkTools.@benchmarkable diff_loss_tn(ps)
+SUITE["main"]["no_inplace"]["direct"]["test"] =
+    BenchmarkTools.@benchmarkable diff_loss_tt(ps)
+SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
+    BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
+        diff_loss_tn,
+        ADTypes.AutoZygote(),
+        ps,
+    )
+SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
+    BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
+        diff_loss_tt,
+        ADTypes.AutoZygote(),
+        ps,
+    )
 SUITE["main"]["inplace"]["direct"]["train"] =
     BenchmarkTools.@benchmarkable diff_loss_tn2(ps)
 SUITE["main"]["inplace"]["direct"]["test"] =
     BenchmarkTools.@benchmarkable diff_loss_tt2(ps)
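Note: the benchmark now draws a fixed dataset from Beta(2, 4) and constructs both the out-of-place (`icnf`) and in-place (`icnf2`) models before any `SUITE` entries are defined, so all warm-up calls happen ahead of registration. A minimal sketch of how the suite might be exercised locally, assuming `benchmark/benchmarks.jl` has been included first:

```julia
import BenchmarkTools

BenchmarkTools.tune!(SUITE)                          # optional: pick evaluation parameters
results = BenchmarkTools.run(SUITE; verbose = true)  # execute every registered benchmark
BenchmarkTools.median(results)                       # one summary estimator per leaf
```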
diff --git a/examples/usage.jl b/examples/usage.jl
index 7d3591ec..eb80271e 100644
--- a/examples/usage.jl
+++ b/examples/usage.jl
@@ -54,7 +54,7 @@ model = ICNFModel(
     icnf;
     optimizers = (Adam(),),
     adtype = AutoZygote(),
-    batch_size = 512,
+    batchsize = 512,
     sol_kwargs = (; progress = true, epochs = 300), # pass to the solver
 )
 
 mach = machine(model, df)
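Note: the keyword is renamed from `batch_size` to `batchsize`, matching the spelling used by `MLUtils.DataLoader` and by the `make_dataloader` helper introduced later in this patch. A hedged sketch of the two settings, reusing the names from the example above:

```julia
model = ICNFModel(icnf; batchsize = 512)  # mini-batches of 512 columns per optimizer step
model = ICNFModel(icnf; batchsize = 0)    # 0 falls back to one full-data batch (matrix mode)
```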
diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl
index 19df19b4..37fa31fd 100644
--- a/src/ContinuousNormalizingFlows.jl
+++ b/src/ContinuousNormalizingFlows.jl
@@ -25,6 +25,7 @@ import ADTypes,
     SciMLSensitivity,
     ScientificTypesBase,
     Statistics,
+    WeightInitializers,
     Zygote
 
 export construct,
@@ -54,13 +55,10 @@ export construct,
 include(joinpath("layers", "cond_layer.jl"))
 include(joinpath("layers", "planar_layer.jl"))
 
-include("types.jl")
-
-include("base_icnf.jl")
-
-include("icnf.jl")
-
-include("utils.jl")
+include(joinpath("core", "types.jl"))
+include(joinpath("core", "base_icnf.jl"))
+include(joinpath("core", "icnf.jl"))
+include(joinpath("core", "utils.jl"))
 
 include(joinpath("exts", "mlj_ext", "core.jl"))
 include(joinpath("exts", "mlj_ext", "core_icnf.jl"))
diff --git a/src/base_icnf.jl b/src/core/base_icnf.jl
similarity index 97%
rename from src/base_icnf.jl
rename to src/core/base_icnf.jl
index 25de89b5..5991348c 100644
--- a/src/base_icnf.jl
+++ b/src/core/base_icnf.jl
@@ -94,12 +94,12 @@ function n_augment_input(::AbstractICNF)
 end
 
 function steer_tspan(
-    icnf::AbstractICNF{T, <:ComputeMode, INPLACE, COND, AUGMENTED, true},
+    icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, COND, AUGMENTED, true},
     ::TrainMode,
-) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED}
+) where {INPLACE, COND, AUGMENTED}
     t₀, t₁ = icnf.tspan
     Δt = abs(t₁ - t₀)
-    r = convert(T, rand(icnf.rng, icnf.steerdist))
+    r = oftype(t₁, rand(icnf.rng, icnf.steerdist))
     t₁_new = muladd(Δt, r, t₁)
     return (t₀, t₁_new)
 end
@@ -504,12 +504,12 @@ function loss(
 end
 
 function make_ode_func(
-    icnf::AbstractICNF{T, CM, INPLACE},
+    icnf::AbstractICNF{T},
     mode::Mode,
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractVecOrMat{T},
-) where {T <: AbstractFloat, CM, INPLACE}
+) where {T <: AbstractFloat}
     function ode_func(u::Any, p::Any, t::Any)
         return augmented_f(u, p, t, icnf, mode, nn, st, ϵ)
     end
 
     return ode_func
 end
 
-function (icnf::AbstractICNF{T, CM, INPLACE, false})(
+function (icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, false})(
     xs::AbstractVecOrMat,
     ps::Any,
     st::NamedTuple,
-) where {T, CM, INPLACE}
+) where {INPLACE}
     return first(inference(icnf, TrainMode(), xs, ps, st)), st
 end
 
-function (icnf::AbstractICNF{T, CM, INPLACE, true})(
+function (icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, true})(
     xs_ys::Tuple,
     ps::Any,
     st::NamedTuple,
-) where {T, CM, INPLACE}
+) where {INPLACE}
     xs, ys = xs_ys
     return first(inference(icnf, TrainMode(), xs, ys, ps, st)), st
 end
diff --git a/src/icnf.jl b/src/core/icnf.jl
similarity index 100%
rename from src/icnf.jl
rename to src/core/icnf.jl
diff --git a/src/types.jl b/src/core/types.jl
similarity index 100%
rename from src/types.jl
rename to src/core/types.jl
diff --git a/src/utils.jl b/src/core/utils.jl
similarity index 100%
rename from src/utils.jl
rename to src/core/utils.jl
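Note: in `steer_tspan`, `oftype(t₁, y)` is shorthand for `convert(typeof(t₁), y)`, so the draw from `steerdist` now follows the runtime element type of `tspan` rather than the removed static parameter `T`. In isolation:

```julia
t₁ = 1.0f0              # Float32 endpoint, as in tspan = (0.0f0, 1.0f0)
r = oftype(t₁, rand())  # rand() returns a Float64; r is converted to Float32
r isa Float32           # true
```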
diff --git a/src/exts/dist_ext/core_cond_icnf.jl b/src/exts/dist_ext/core_cond_icnf.jl
index ca12e213..bd60d667 100644
--- a/src/exts/dist_ext/core_cond_icnf.jl
+++ b/src/exts/dist_ext/core_cond_icnf.jl
@@ -15,51 +15,66 @@ function CondICNFDist(
     return CondICNFDist(mach.model.m, mode, ys, ps, st)
 end
 
-function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real})
-    return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
-        first(inference(d.m, d.mode, x, d.ys, d.ps, d.st))
-    elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
-        @warn "to compute by matrices, data should be a matrix."
-        first(Distributions._logpdf(d, hcat(x)))
-    else
-        error("Not Implemented")
-    end
+function Distributions._logpdf(
+    d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
+    x::AbstractVector{<:Real},
+)
+    return first(inference(d.m, d.mode, x, d.ys, d.ps, d.st))
+end
+
+function Distributions._logpdf(
+    d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
+    x::AbstractVector{<:Real},
+)
+    @warn "to compute by matrices, data should be a matrix."
+    return first(Distributions._logpdf(d, hcat(x)))
+end
+
+function Distributions._logpdf(
+    d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
+    A::AbstractMatrix{<:Real},
+)
+    @warn "to compute by vectors, data should be a vector."
+    return Distributions._logpdf.(d, collect(collect.(eachcol(A))))
 end
-function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real})
-    return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
-        @warn "to compute by vectors, data should be a vector."
-        Distributions._logpdf.(d, collect(collect.(eachcol(A))))
-    elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
-        first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st))
-    else
-        error("Not Implemented")
-    end
+
+function Distributions._logpdf(
+    d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
+    A::AbstractMatrix{<:Real},
+)
+    return first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st))
 end
+
 function Distributions._rand!(
     rng::Random.AbstractRNG,
-    d::CondICNFDist,
+    d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
     x::AbstractVector{<:Real},
 )
-    return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
-        x .= generate(d.m, d.mode, d.ys, d.ps, d.st)
-    elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
-        @warn "to compute by matrices, data should be a matrix."
-        x .= Distributions._rand!(rng, d, hcat(x))
-    else
-        error("Not Implemented")
-    end
+    return x .= generate(d.m, d.mode, d.ys, d.ps, d.st)
 end
+
+function Distributions._rand!(
+    rng::Random.AbstractRNG,
+    d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
+    x::AbstractVector{<:Real},
+)
+    @warn "to compute by matrices, data should be a matrix."
+    return x .= Distributions._rand!(rng, d, hcat(x))
+end
+
+function Distributions._rand!(
+    rng::Random.AbstractRNG,
+    d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
+    A::AbstractMatrix{<:Real},
+)
+    @warn "to compute by vectors, data should be a vector."
+    return A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...)
+end
+
 function Distributions._rand!(
     rng::Random.AbstractRNG,
-    d::CondICNFDist,
+    d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
     A::AbstractMatrix{<:Real},
 )
-    return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
-        @warn "to compute by vectors, data should be a vector."
-        A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...)
-    elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
-        A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2))
-    else
-        error("Not Implemented")
-    end
+    return A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2))
 end
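Note: the rewrite replaces the runtime `isa` chains with method dispatch on the compute mode carried in `CondICNFDist`'s type parameter, so the vector/matrix branch is selected statically and the `error("Not Implemented")` fallbacks disappear. The pattern in miniature, with illustrative names that are not part of the package:

```julia
abstract type DemoMode end          # stand-in for ComputeMode
struct DemoVec <: DemoMode end      # stand-in for VectorMode
struct DemoMat <: DemoMode end      # stand-in for MatrixMode
struct DemoDist{M <: DemoMode} end  # stand-in for CondICNFDist

logdensity(::DemoDist{DemoVec}, x::AbstractVector) = "vector path"
logdensity(::DemoDist{DemoMat}, A::AbstractMatrix) = "matrix path"

logdensity(DemoDist{DemoVec}(), rand(3))  # method chosen by dispatch, not by `isa`
```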
diff --git a/src/exts/dist_ext/core_icnf.jl b/src/exts/dist_ext/core_icnf.jl
index 2c937229..05d8279e 100644
--- a/src/exts/dist_ext/core_icnf.jl
+++ b/src/exts/dist_ext/core_icnf.jl
@@ -10,53 +10,66 @@ function ICNFDist(mach::MLJBase.Machine{<:ICNFModel}, mode::Mode)
     return ICNFDist(mach.model.m, mode, ps, st)
 end
 
-function Distributions._logpdf(d::ICNFDist, x::AbstractVector{<:Real})
-    return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
-        first(inference(d.m, d.mode, x, d.ps, d.st))
-    elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
-        @warn "to compute by matrices, data should be a matrix."
-        first(Distributions._logpdf(d, hcat(x)))
-    else
-        error("Not Implemented")
-    end
-end
-
-function Distributions._logpdf(d::ICNFDist, A::AbstractMatrix{<:Real})
-    return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
-        @warn "to compute by vectors, data should be a vector."
-        Distributions._logpdf.(d, collect(collect.(eachcol(A))))
-    elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
-        first(inference(d.m, d.mode, A, d.ps, d.st))
-    else
-        error("Not Implemented")
-    end
+function Distributions._logpdf(
+    d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
+    x::AbstractVector{<:Real},
+)
+    return first(inference(d.m, d.mode, x, d.ps, d.st))
+end
+
+function Distributions._logpdf(
+    d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
+    x::AbstractVector{<:Real},
+)
+    @warn "to compute by matrices, data should be a matrix."
+    return first(Distributions._logpdf(d, hcat(x)))
+end
+
+function Distributions._logpdf(
+    d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
+    A::AbstractMatrix{<:Real},
+)
+    @warn "to compute by vectors, data should be a vector."
+    return Distributions._logpdf.(d, collect(collect.(eachcol(A))))
+end
+
+function Distributions._logpdf(
+    d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
+    A::AbstractMatrix{<:Real},
+)
+    return first(inference(d.m, d.mode, A, d.ps, d.st))
 end
 
 function Distributions._rand!(
     rng::Random.AbstractRNG,
-    d::ICNFDist,
+    d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
     x::AbstractVector{<:Real},
 )
-    return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
-        x .= generate(d.m, d.mode, d.ps, d.st)
-    elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
-        @warn "to compute by matrices, data should be a matrix."
-        x .= Distributions._rand!(rng, d, hcat(x))
-    else
-        error("Not Implemented")
-    end
+    return x .= generate(d.m, d.mode, d.ps, d.st)
 end
+
+function Distributions._rand!(
+    rng::Random.AbstractRNG,
+    d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
+    x::AbstractVector{<:Real},
+)
+    @warn "to compute by matrices, data should be a matrix."
+    return x .= Distributions._rand!(rng, d, hcat(x))
+end
+
+function Distributions._rand!(
+    rng::Random.AbstractRNG,
+    d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
+    A::AbstractMatrix{<:Real},
+)
+    @warn "to compute by vectors, data should be a vector."
+    return A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...)
+end
+
 function Distributions._rand!(
     rng::Random.AbstractRNG,
-    d::ICNFDist,
+    d::ICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
     A::AbstractMatrix{<:Real},
 )
-    return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
-        @warn "to compute by vectors, data should be a vector."
-        A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...)
-    elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
-        A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2))
-    else
-        error("Not Implemented")
-    end
+    return A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2))
 end
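Note: since `ICNFDist` implements `Distributions._logpdf` and `Distributions._rand!`, the public Distributions.jl front end applies to it. A hypothetical round trip, assuming a fitted machine `mach` and that `TestMode` is the package's inference-time mode:

```julia
import Distributions

d = ICNFDist(mach, TestMode())     # wrap the fitted flow as a distribution
xs = rand(d, 16)                   # 16 samples (columns), routed through `_rand!`
Distributions.logpdf(d, xs[:, 1])  # log-density of one sample, routed through `_logpdf`
```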
diff --git a/src/exts/mlj_ext/core.jl b/src/exts/mlj_ext/core.jl
index 17e27a35..da466ae8 100644
--- a/src/exts/mlj_ext/core.jl
+++ b/src/exts/mlj_ext/core.jl
@@ -3,12 +3,7 @@ function MLJModelInterface.fitted_params(::MLJICNF, fitresult)
     return (learned_parameters = ps, states = st)
 end
 
-function make_opt_loss(
-    icnf::AbstractICNF{T, CM, INPLACE, COND},
-    mode::Mode,
-    st::NamedTuple,
-    loss_::Function,
-) where {T, CM, INPLACE, COND}
+function make_opt_loss(icnf::AbstractICNF, mode::Mode, st::NamedTuple, loss_::Function)
     function opt_loss(u::Any, data::Tuple{<:Any})
         xs, = data
         return loss_(icnf, mode, xs, u, st)
     end
@@ -21,3 +16,35 @@
     return opt_loss
 end
+
+function make_dataloader(
+    icnf::AbstractICNF{<:AbstractFloat, <:VectorMode},
+    ::Int,
+    data::Tuple,
+)
+    return MLUtils.DataLoader(
+        data;
+        batchsize = -1,
+        shuffle = true,
+        partial = true,
+        rng = icnf.rng,
+    )
+end
+
+function make_dataloader(
+    icnf::AbstractICNF{<:AbstractFloat, <:MatrixMode},
+    batchsize::Int,
+    data::Tuple,
+)
+    return MLUtils.DataLoader(
+        data;
+        batchsize = if iszero(batchsize)
+            last(maximum(broadcast(size, data)))
+        else
+            batchsize
+        end,
+        shuffle = true,
+        partial = true,
+        rng = icnf.rng,
+    )
+end
diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl
index ae8b2056..67b785e1 100644
--- a/src/exts/mlj_ext/core_cond_icnf.jl
+++ b/src/exts/mlj_ext/core_cond_icnf.jl
@@ -5,7 +5,7 @@ mutable struct CondICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF}
     optimizers::Tuple
     adtype::ADTypes.AbstractADType
-    batch_size::Int
+    batchsize::Int
     sol_kwargs::NamedTuple
 end
 
@@ -14,10 +14,10 @@ function CondICNFModel(
     loss::Function = loss;
     optimizers::Tuple = (Optimisers.Adam(),),
     adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(),
-    batch_size::Int = 32,
+    batchsize::Int = 32,
     sol_kwargs::NamedTuple = (;),
 )
-    return CondICNFModel(m, loss, optimizers, adtype, batch_size, sol_kwargs)
+    return CondICNFModel(m, loss, optimizers, adtype, batchsize, sol_kwargs)
 end
 
 function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
@@ -30,22 +30,7 @@
     y = model.m.device(y)
     ps = model.m.device(ps)
     st = model.m.device(st)
-    data = if model.m.compute_mode isa VectorMode
-        MLUtils.DataLoader((x, y); batchsize = -1, shuffle = true, partial = true)
-    elseif model.m.compute_mode isa MatrixMode
-        MLUtils.DataLoader(
-            (x, y);
-            batchsize = if iszero(model.batch_size)
-                max(size(x, 2), size(y, 2))
-            else
-                model.batch_size
-            end,
-            shuffle = true,
-            partial = true,
-        )
-    else
-        error("Not Implemented")
-    end
+    data = make_dataloader(model.m, model.batchsize, (x, y))
     data = model.m.device(data)
     optprob = SciMLBase.OptimizationProblem{true}(
         SciMLBase.OptimizationFunction{true}(
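Note: both `fit` methods now delegate batching to the shared `make_dataloader` helper above; in `MLUtils.DataLoader` a negative `batchsize` means no batching (single observations, matching vector-mode models), while matrix-mode models get column batches. A package-independent sketch:

```julia
import MLUtils

x = rand(Float32, 2, 10)                                         # 10 observations (columns)
dl_vec = MLUtils.DataLoader((x,); batchsize = -1)                # one observation at a time
dl_mat = MLUtils.DataLoader((x,); batchsize = 4, partial = true) # column batches of 4

length(dl_vec)  # 10 single-observation iterates
length(dl_mat)  # 3 batches: 4 + 4 + 2 columns
```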
diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl
index a38120e2..c5a8dbe3 100644
--- a/src/exts/mlj_ext/core_icnf.jl
+++ b/src/exts/mlj_ext/core_icnf.jl
@@ -5,7 +5,7 @@ mutable struct ICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF}
     optimizers::Tuple
     adtype::ADTypes.AbstractADType
-    batch_size::Int
+    batchsize::Int
     sol_kwargs::NamedTuple
 end
 
@@ -14,10 +14,10 @@ function ICNFModel(
     loss::Function = loss;
     optimizers::Tuple = (Optimisers.Adam(),),
     adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(),
-    batch_size::Int = 32,
+    batchsize::Int = 32,
     sol_kwargs::NamedTuple = (;),
 )
-    return ICNFModel(m, loss, optimizers, adtype, batch_size, sol_kwargs)
+    return ICNFModel(m, loss, optimizers, adtype, batchsize, sol_kwargs)
 end
 
 function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
@@ -27,22 +27,7 @@
     x = model.m.device(x)
     ps = model.m.device(ps)
     st = model.m.device(st)
-    data = if model.m.compute_mode isa VectorMode
-        MLUtils.DataLoader((x,); batchsize = -1, shuffle = true, partial = true)
-    elseif model.m.compute_mode isa MatrixMode
-        MLUtils.DataLoader(
-            (x,);
-            batchsize = if iszero(model.batch_size)
-                size(x, 2)
-            else
-                model.batch_size
-            end,
-            shuffle = true,
-            partial = true,
-        )
-    else
-        error("Not Implemented")
-    end
+    data = make_dataloader(model.m, model.batchsize, (x,))
     data = model.m.device(data)
     optprob = SciMLBase.OptimizationProblem{true}(
         SciMLBase.OptimizationFunction{true}(
diff --git a/src/layers/planar_layer.jl b/src/layers/planar_layer.jl
index a1d8290d..5682ac00 100644
--- a/src/layers/planar_layer.jl
+++ b/src/layers/planar_layer.jl
@@ -14,8 +14,8 @@ end
 function PlanarLayer(
     nvars::Int,
     activation::Any = identity;
-    init_weight::Any = Lux.glorot_uniform,
-    init_bias::Any = Lux.zeros32,
+    init_weight::Any = WeightInitializers.glorot_uniform,
+    init_bias::Any = WeightInitializers.zeros32,
     use_bias::Bool = true,
     n_cond::Int = 0,
 )
diff --git a/test/regression_tests.jl b/test/ci_tests/regression_tests.jl
similarity index 98%
rename from test/regression_tests.jl
rename to test/ci_tests/regression_tests.jl
index 3089e93d..f90d02a0 100644
--- a/test/regression_tests.jl
+++ b/test/ci_tests/regression_tests.jl
@@ -34,7 +34,7 @@ Test.@testset "Regression Tests" begin
     df = DataFrames.DataFrame(transpose(r), :auto)
     model = ContinuousNormalizingFlows.ICNFModel(
         icnf;
-        batch_size = 0,
+        batchsize = 0,
         sol_kwargs = (; progress = true, epochs = 300),
     )
diff --git a/test/smoke_tests.jl b/test/ci_tests/smoke_tests.jl
similarity index 98%
rename from test/smoke_tests.jl
rename to test/ci_tests/smoke_tests.jl
index 1bd2fe26..13d71fad 100644
--- a/test/smoke_tests.jl
+++ b/test/ci_tests/smoke_tests.jl
@@ -127,10 +127,10 @@ Test.@testset "Smoke Tests" begin
     )
     ps, st = LuxCore.setup(icnf.rng, icnf)
     ps = ComponentArrays.ComponentArray(ps)
-    r = device(r)
-    r2 = device(r2)
-    ps = device(ps)
-    st = device(st)
+    r = icnf.device(r)
+    r2 = icnf.device(r2)
+    ps = icnf.device(ps)
+    st = icnf.device(st)
     if GROUP != "All" &&
        compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
@@ -224,7 +224,7 @@
     model = ContinuousNormalizingFlows.CondICNFModel(
         icnf;
         adtype,
-        batch_size = 0,
+        batchsize = 0,
         sol_kwargs = (; progress = true, epochs = 2),
     )
     mach = MLJBase.machine(model, (df, df2))
@@ -251,7 +251,7 @@
     model = ContinuousNormalizingFlows.ICNFModel(
         icnf;
         adtype,
-        batch_size = 0,
+        batchsize = 0,
         sol_kwargs = (; progress = true, epochs = 2),
     )
     mach = MLJBase.machine(model, df)
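Note: the `PlanarLayer` defaults now come from WeightInitializers.jl directly instead of through Lux's re-exports; they are the same initializer functions. A quick sketch with arbitrary sizes:

```julia
import Random, WeightInitializers

rng = Random.default_rng()
W = WeightInitializers.glorot_uniform(rng, 4, 3)  # 4×3 Float32 Glorot-uniform weights
b = WeightInitializers.zeros32(rng, 4)            # length-4 Float32 zero bias
```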
diff --git a/test/speed_tests.jl b/test/ci_tests/speed_tests.jl
similarity index 99%
rename from test/speed_tests.jl
rename to test/ci_tests/speed_tests.jl
index 8056aa1a..5b0780ba 100644
--- a/test/speed_tests.jl
+++ b/test/ci_tests/speed_tests.jl
@@ -62,7 +62,7 @@ Test.@testset "Speed Tests" begin
 
     model = ContinuousNormalizingFlows.ICNFModel(
         icnf;
-        batch_size = 0,
+        batchsize = 0,
         sol_kwargs = (; epochs = 5),
     )
 
diff --git a/test/checkby_Aqua_tests.jl b/test/quality_tests/checkby_Aqua_tests.jl
similarity index 100%
rename from test/checkby_Aqua_tests.jl
rename to test/quality_tests/checkby_Aqua_tests.jl
diff --git a/test/checkby_ExplicitImports_tests.jl b/test/quality_tests/checkby_ExplicitImports_tests.jl
similarity index 100%
rename from test/checkby_ExplicitImports_tests.jl
rename to test/quality_tests/checkby_ExplicitImports_tests.jl
diff --git a/test/checkby_JET_tests.jl b/test/quality_tests/checkby_JET_tests.jl
similarity index 98%
rename from test/checkby_JET_tests.jl
rename to test/quality_tests/checkby_JET_tests.jl
index 021bb0d9..efc7240e 100644
--- a/test/checkby_JET_tests.jl
+++ b/test/quality_tests/checkby_JET_tests.jl
@@ -111,10 +111,10 @@ Test.@testset "CheckByJET" begin
     )
     ps, st = LuxCore.setup(icnf.rng, icnf)
     ps = ComponentArrays.ComponentArray(ps)
-    r = device(r)
-    r2 = device(r2)
-    ps = device(ps)
-    st = device(st)
+    r = icnf.device(r)
+    r2 = icnf.device(r2)
+    ps = icnf.device(ps)
+    st = icnf.device(st)
     if GROUP != "All" &&
        compute_mode.adback isa ADTypes.AutoEnzyme{<:Enzyme.ForwardMode} &&
diff --git a/test/runtests.jl b/test/runtests.jl
index 1e542e87..1c916e13 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -33,26 +33,26 @@ end
 
 Test.@testset "Overall" begin
     if GROUP == "All" || GROUP in ["SmokeXOut", "SmokeXIn", "SmokeXYOut", "SmokeXYIn"]
-        include("smoke_tests.jl")
+        include(joinpath("ci_tests", "smoke_tests.jl"))
     end
     if GROUP == "All" || GROUP == "Regression"
-        include("regression_tests.jl")
+        include(joinpath("ci_tests", "regression_tests.jl"))
     end
     if GROUP == "All" || GROUP == "Speed"
-        include("speed_tests.jl")
+        include(joinpath("ci_tests", "speed_tests.jl"))
     end
     if GROUP == "All" || GROUP == "CheckByAqua"
-        include("checkby_Aqua_tests.jl")
+        include(joinpath("quality_tests", "checkby_Aqua_tests.jl"))
     end
     if GROUP == "All" || GROUP == "CheckByJET"
-        include("checkby_JET_tests.jl")
+        include(joinpath("quality_tests", "checkby_JET_tests.jl"))
     end
     if GROUP == "All" || GROUP == "CheckByExplicitImports"
-        include("checkby_ExplicitImports_tests.jl")
+        include(joinpath("quality_tests", "checkby_ExplicitImports_tests.jl"))
     end
 end
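Note: with the tests split into test/ci_tests and test/quality_tests, the GROUP switch keeps selecting suites. A hypothetical local invocation, assuming GROUP is read from the environment (e.g. `GROUP = get(ENV, "GROUP", "All")`) in test/runtests.jl:

```julia
# From the repository root, run only the regression suite:
withenv("GROUP" => "Regression") do
    include(joinpath("test", "runtests.jl"))
end
```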