diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl
index e5368cf0..6717b3f4 100644
--- a/benchmark/benchmarks.jl
+++ b/benchmark/benchmarks.jl
@@ -70,7 +70,7 @@ ps = ComponentArrays.ComponentArray(ps)
 function diff_loss_tn(x::Any)
     return ContinuousNormalizingFlows.loss(
         icnf,
-        ContinuousNormalizingFlows.TrainMode(),
+        ContinuousNormalizingFlows.TrainMode{true}(),
         r,
         x,
         st,
@@ -79,7 +79,7 @@ end
 function diff_loss_tt(x::Any)
     return ContinuousNormalizingFlows.loss(
         icnf,
-        ContinuousNormalizingFlows.TestMode(),
+        ContinuousNormalizingFlows.TestMode{true}(),
         r,
         x,
         st,
@@ -89,7 +89,7 @@ end
 function diff_loss_tn2(x::Any)
     return ContinuousNormalizingFlows.loss(
         icnf2,
-        ContinuousNormalizingFlows.TrainMode(),
+        ContinuousNormalizingFlows.TrainMode{true}(),
         r,
         x,
         st,
@@ -98,7 +98,7 @@ end
 function diff_loss_tt2(x::Any)
     return ContinuousNormalizingFlows.loss(
         icnf2,
-        ContinuousNormalizingFlows.TestMode(),
+        ContinuousNormalizingFlows.TestMode{true}(),
         r,
         x,
         st,
diff --git a/src/core/base_icnf.jl b/src/core/base_icnf.jl
index 5991348c..b73de94b 100644
--- a/src/core/base_icnf.jl
+++ b/src/core/base_icnf.jl
@@ -95,7 +95,7 @@ end

 function steer_tspan(
     icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, COND, AUGMENTED, true},
-    ::TrainMode,
+    ::TrainMode{true},
 ) where {INPLACE, COND, AUGMENTED}
     t₀, t₁ = icnf.tspan
     Δt = abs(t₁ - t₀)
@@ -124,9 +124,9 @@ end

 function inference_sol(
     icnf::AbstractICNF{T, <:VectorMode, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG},
-    mode::Mode,
+    mode::Mode{REG},
     prob::SciMLBase.AbstractODEProblem{<:AbstractVector{<:Real}, NTuple{2, T}, INPLACE},
-) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG}
+) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG, REG}
     n_aug = n_augment(icnf, mode)
     fsol = base_sol(icnf, prob)
     z = fsol[begin:(end - n_aug - 1)]
@@ -134,7 +134,7 @@
     augs = fsol[(end - n_aug + 1):end]
     logpz = oftype(Δlogp, Distributions.logpdf(icnf.basedist, z))
     logp̂x = logpz - Δlogp
-    Ȧ = if (NORM_Z_AUG && AUGMENTED)
+    Ȧ = if NORM_Z_AUG && AUGMENTED && REG
         n_aug_input = n_augment_input(icnf)
         z_aug = z[(end - n_aug_input + 1):end]
         LinearAlgebra.norm(z_aug)
@@ -146,9 +146,9 @@ end

 function inference_sol(
     icnf::AbstractICNF{T, <:MatrixMode, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG},
-    mode::Mode,
+    mode::Mode{REG},
     prob::SciMLBase.AbstractODEProblem{<:AbstractMatrix{<:Real}, NTuple{2, T}, INPLACE},
-) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG}
+) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG, REG}
     n_aug = n_augment(icnf, mode)
     fsol = base_sol(icnf, prob)
     z = fsol[begin:(end - n_aug - 1), :]
@@ -156,7 +156,7 @@
     augs = fsol[(end - n_aug + 1):end, :]
     logpz = oftype(Δlogp, Distributions.logpdf(icnf.basedist, z))
     logp̂x = logpz - Δlogp
-    Ȧ = transpose(if (NORM_Z_AUG && AUGMENTED)
+    Ȧ = transpose(if NORM_Z_AUG && AUGMENTED && REG
         n_aug_input = n_augment_input(icnf)
         z_aug = z[(end - n_aug_input + 1):end, :]
         LinearAlgebra.norm.(eachcol(z_aug))
@@ -526,7 +526,7 @@ function (icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, false})(
     ps::Any,
     st::NamedTuple,
 ) where {INPLACE}
-    return first(inference(icnf, TrainMode(), xs, ps, st)), st
+    return first(inference(icnf, TrainMode{false}(), xs, ps, st)), st
 end

 function (icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, true})(
@@ -535,5 +535,5 @@ function (icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, true})(
     st::NamedTuple,
 ) where {INPLACE}
     xs, ys = xs_ys
-    return first(inference(icnf, TrainMode(), xs, ys, ps, st)), st
+    return first(inference(icnf, TrainMode{false}(), xs, ys, ps, st)), st
 end
diff --git a/src/core/icnf.jl b/src/core/icnf.jl
index 43cf13fd..0c4622f0 100644
--- a/src/core/icnf.jl
+++ b/src/core/icnf.jl
@@ -103,7 +103,7 @@ struct ICNF{
     λ₃::T
 end

-function n_augment(::ICNF, ::TrainMode)
+function n_augment(::ICNF, ::Mode)
     return 2
 end

@@ -111,18 +111,24 @@ function augmented_f(
     u::Any,
     p::Any,
     ::Any,
-    icnf::ICNF{T, <:DIVectorMode, false},
-    mode::TestMode,
+    icnf::ICNF{T, <:DIVectorMode, false, COND, AUGMENTED, STEER, NORM_Z},
+    mode::TestMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractVector{T},
-) where {T <: AbstractFloat}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, J = icnf_jacobian(icnf, mode, snn, z)
     l̇ = -LinearAlgebra.tr(J)
-    return vcat(ż, l̇)
+    Ė = if NORM_Z && REG
+        LinearAlgebra.norm(ż)
+    else
+        zero(T)
+    end
+    ṅ = zero(T)
+    return vcat(ż, l̇, Ė, ṅ)
 end

 function augmented_f(
@@ -130,18 +136,24 @@
     u::Any,
     p::Any,
     ::Any,
-    icnf::ICNF{T, <:DIVectorMode, true},
-    mode::TestMode,
+    icnf::ICNF{T, <:DIVectorMode, true, COND, AUGMENTED, STEER, NORM_Z},
+    mode::TestMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractVector{T},
-) where {T <: AbstractFloat}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, J = icnf_jacobian(icnf, mode, snn, z)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.tr(J)
+    du[(end - n_aug + 1)] = if NORM_Z && REG
+        LinearAlgebra.norm(ż)
+    else
+        zero(T)
+    end
+    du[(end - n_aug + 2)] = zero(T)
     return nothing
 end

@@ -149,18 +161,30 @@
     u::Any,
     p::Any,
     ::Any,
-    icnf::ICNF{T, <:MatrixMode, false},
-    mode::TestMode,
+    icnf::ICNF{T, <:MatrixMode, false, COND, AUGMENTED, STEER, NORM_Z},
+    mode::TestMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractMatrix{T},
-) where {T <: AbstractFloat}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, J = icnf_jacobian(icnf, mode, snn, z)
     l̇ = -transpose(LinearAlgebra.tr.(eachslice(J; dims = 3)))
-    return vcat(ż, l̇)
+    Ė = transpose(if NORM_Z && REG
+        LinearAlgebra.norm.(eachcol(ż))
+    else
+        zrs_Ė = similar(ż, size(ż, 2))
+        ChainRulesCore.@ignore_derivatives fill!(zrs_Ė, zero(T))
+        zrs_Ė
+    end)
+    ṅ = transpose(begin
+        zrs_ṅ = similar(ż, size(ż, 2))
+        ChainRulesCore.@ignore_derivatives fill!(zrs_ṅ, zero(T))
+        zrs_ṅ
+    end)
+    return vcat(ż, l̇, Ė, ṅ)
 end

 function augmented_f(
@@ -168,18 +192,24 @@
     u::Any,
     p::Any,
     ::Any,
-    icnf::ICNF{T, <:MatrixMode, true},
-    mode::TestMode,
+    icnf::ICNF{T, <:MatrixMode, true, COND, AUGMENTED, STEER, NORM_Z},
+    mode::TestMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractMatrix{T},
-) where {T <: AbstractFloat}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, J = icnf_jacobian(icnf, mode, snn, z)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -(LinearAlgebra.tr.(eachslice(J; dims = 3)))
+    du[(end - n_aug + 1), :] .= if NORM_Z && REG
+        LinearAlgebra.norm.(eachcol(ż))
+    else
+        zero(T)
+    end
+    du[(end - n_aug + 2), :] .= zero(T)
     return nothing
 end

@@ -188,22 +218,22 @@ function augmented_f(
     p::Any,
     ::Any,
     icnf::ICNF{T, <:DIVecJacVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
-    mode::TrainMode,
+    mode::TrainMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractVector{T},
-) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     l̇ = -LinearAlgebra.dot(ϵJ, ϵ)
-    Ė = if NORM_Z
+    Ė = if NORM_Z && REG
         LinearAlgebra.norm(ż)
     else
         zero(T)
     end
-    ṅ = if NORM_J
+    ṅ = if NORM_J && REG
         LinearAlgebra.norm(ϵJ)
     else
         zero(T)
@@ -217,23 +247,23 @@
     p::Any,
     ::Any,
     icnf::ICNF{T, <:DIVecJacVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
-    mode::TrainMode,
+    mode::TrainMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractVector{T},
-) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.dot(ϵJ, ϵ)
-    du[(end - n_aug + 1)] = if NORM_Z
+    du[(end - n_aug + 1)] = if NORM_Z && REG
         LinearAlgebra.norm(ż)
     else
         zero(T)
     end
-    du[(end - n_aug + 2)] = if NORM_J
+    du[(end - n_aug + 2)] = if NORM_J && REG
         LinearAlgebra.norm(ϵJ)
     else
         zero(T)
@@ -246,22 +276,22 @@
     p::Any,
     ::Any,
     icnf::ICNF{T, <:DIJacVecVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
-    mode::TrainMode,
+    mode::TrainMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractVector{T},
-) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     l̇ = -LinearAlgebra.dot(ϵ, Jϵ)
-    Ė = if NORM_Z
+    Ė = if NORM_Z && REG
         LinearAlgebra.norm(ż)
     else
         zero(T)
     end
-    ṅ = if NORM_J
+    ṅ = if NORM_J && REG
         LinearAlgebra.norm(Jϵ)
     else
         zero(T)
@@ -275,23 +305,23 @@
     p::Any,
     ::Any,
     icnf::ICNF{T, <:DIJacVecVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
-    mode::TrainMode,
+    mode::TrainMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractVector{T},
-) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
     ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.dot(ϵ, Jϵ)
-    du[(end - n_aug + 1)] = if NORM_Z
+    du[(end - n_aug + 1)] = if NORM_Z && REG
         LinearAlgebra.norm(ż)
     else
         zero(T)
     end
-    du[(end - n_aug + 2)] = if NORM_J
+    du[(end - n_aug + 2)] = if NORM_J && REG
         LinearAlgebra.norm(Jϵ)
     else
         zero(T)
@@ -304,24 +334,24 @@
     p::Any,
     ::Any,
     icnf::ICNF{T, <:DIVecJacMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
-    mode::TrainMode,
+    mode::TrainMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractMatrix{T},
-) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     l̇ = -sum(ϵJ .* ϵ; dims = 1)
-    Ė = transpose(if NORM_Z
+    Ė = transpose(if NORM_Z && REG
         LinearAlgebra.norm.(eachcol(ż))
     else
         zrs_Ė = similar(ż, size(ż, 2))
         ChainRulesCore.@ignore_derivatives fill!(zrs_Ė, zero(T))
         zrs_Ė
     end)
-    ṅ = transpose(if NORM_J
+    ṅ = transpose(if NORM_J && REG
         LinearAlgebra.norm.(eachcol(ϵJ))
     else
         zrs_ṅ = similar(ż, size(ż, 2))
@@ -337,23 +367,23 @@
     p::Any,
     ::Any,
     icnf::ICNF{T, <:DIVecJacMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
-    mode::TrainMode,
+    mode::TrainMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractMatrix{T},
-) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1))
-    du[(end - n_aug + 1), :] .= if NORM_Z
+    du[(end - n_aug + 1), :] .= if NORM_Z && REG
         LinearAlgebra.norm.(eachcol(ż))
     else
         zero(T)
     end
-    du[(end - n_aug + 2), :] .= if NORM_J
+    du[(end - n_aug + 2), :] .= if NORM_J && REG
         LinearAlgebra.norm.(eachcol(ϵJ))
     else
         zero(T)
@@ -366,24 +396,24 @@
     p::Any,
     ::Any,
     icnf::ICNF{T, <:DIJacVecMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
-    mode::TrainMode,
+    mode::TrainMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractMatrix{T},
-) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     l̇ = -sum(ϵ .* Jϵ; dims = 1)
-    Ė = transpose(if NORM_Z
+    Ė = transpose(if NORM_Z && REG
         LinearAlgebra.norm.(eachcol(ż))
     else
         zrs_Ė = similar(ż, size(ż, 2))
         ChainRulesCore.@ignore_derivatives fill!(zrs_Ė, zero(T))
         zrs_Ė
     end)
-    ṅ = transpose(if NORM_J
+    ṅ = transpose(if NORM_J && REG
         LinearAlgebra.norm.(eachcol(Jϵ))
     else
         zrs_ṅ = similar(ż, size(ż, 2))
@@ -399,23 +429,23 @@
     p::Any,
     ::Any,
     icnf::ICNF{T, <:DIJacVecMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
-    mode::TrainMode,
+    mode::TrainMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractMatrix{T},
-) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1))
-    du[(end - n_aug + 1), :] .= if NORM_Z
+    du[(end - n_aug + 1), :] .= if NORM_Z && REG
         LinearAlgebra.norm.(eachcol(ż))
     else
         zero(T)
     end
-    du[(end - n_aug + 2), :] .= if NORM_J
+    du[(end - n_aug + 2), :] .= if NORM_J && REG
         LinearAlgebra.norm.(eachcol(Jϵ))
     else
         zero(T)
@@ -428,24 +458,24 @@
     p::Any,
     ::Any,
     icnf::ICNF{T, <:LuxVecJacMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
-    mode::TrainMode,
+    mode::TrainMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractMatrix{T},
-) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     l̇ = -sum(ϵJ .* ϵ; dims = 1)
-    Ė = transpose(if NORM_Z
+    Ė = transpose(if NORM_Z && REG
         LinearAlgebra.norm.(eachcol(ż))
     else
         zrs_Ė = similar(ż, size(ż, 2))
         ChainRulesCore.@ignore_derivatives fill!(zrs_Ė, zero(T))
         zrs_Ė
     end)
-    ṅ = transpose(if NORM_J
+    ṅ = transpose(if NORM_J && REG
         LinearAlgebra.norm.(eachcol(ϵJ))
     else
         zrs_ṅ = similar(ż, size(ż, 2))
@@ -461,23 +491,23 @@
     p::Any,
     ::Any,
     icnf::ICNF{T, <:LuxVecJacMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
-    mode::TrainMode,
+    mode::TrainMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractMatrix{T},
-) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1))
-    du[(end - n_aug + 1), :] .= if NORM_Z
+    du[(end - n_aug + 1), :] .= if NORM_Z && REG
         LinearAlgebra.norm.(eachcol(ż))
     else
         zero(T)
     end
-    du[(end - n_aug + 2), :] .= if NORM_J
+    du[(end - n_aug + 2), :] .= if NORM_J && REG
         LinearAlgebra.norm.(eachcol(ϵJ))
     else
         zero(T)
@@ -490,24 +520,24 @@
     p::Any,
     ::Any,
     icnf::ICNF{T, <:LuxJacVecMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
-    mode::TrainMode,
+    mode::TrainMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractMatrix{T},
-) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     l̇ = -sum(ϵ .* Jϵ; dims = 1)
-    Ė = transpose(if NORM_Z
+    Ė = transpose(if NORM_Z && REG
         LinearAlgebra.norm.(eachcol(ż))
     else
         zrs_Ė = similar(ż, size(ż, 2))
         ChainRulesCore.@ignore_derivatives fill!(zrs_Ė, zero(T))
         zrs_Ė
     end)
-    ṅ = transpose(if NORM_J
+    ṅ = transpose(if NORM_J && REG
         LinearAlgebra.norm.(eachcol(Jϵ))
     else
         zrs_ṅ = similar(ż, size(ż, 2))
@@ -523,23 +553,23 @@
     p::Any,
     ::Any,
     icnf::ICNF{T, <:LuxJacVecMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
-    mode::TrainMode,
+    mode::TrainMode{REG},
     nn::LuxCore.AbstractLuxLayer,
     st::NamedTuple,
     ϵ::AbstractMatrix{T},
-) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
+) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J, REG}
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
     ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1))
-    du[(end - n_aug + 1), :] .= if NORM_Z
+    du[(end - n_aug + 1), :] .= if NORM_Z && REG
         LinearAlgebra.norm.(eachcol(ż))
     else
         zero(T)
     end
-    du[(end - n_aug + 2), :] .= if NORM_J
+    du[(end - n_aug + 2), :] .= if NORM_J && REG
         LinearAlgebra.norm.(eachcol(Jϵ))
     else
         zero(T)
@@ -549,7 +579,7 @@ end

 function loss(
     icnf::ICNF{<:AbstractFloat, <:VectorMode},
-    mode::TrainMode,
+    mode::Mode,
     xs::AbstractVector{<:Real},
     ps::Any,
     st::NamedTuple,
@@ -560,7 +590,7 @@ end

 function loss(
     icnf::ICNF{<:AbstractFloat, <:VectorMode},
-    mode::TrainMode,
+    mode::Mode,
     xs::AbstractVector{<:Real},
     ys::AbstractVector{<:Real},
     ps::Any,
@@ -572,7 +602,7 @@ end

 function loss(
     icnf::ICNF{<:AbstractFloat, <:MatrixMode},
-    mode::TrainMode,
+    mode::Mode,
     xs::AbstractMatrix{<:Real},
     ps::Any,
     st::NamedTuple,
@@ -583,7 +613,7 @@ end

 function loss(
     icnf::ICNF{<:AbstractFloat, <:MatrixMode},
-    mode::TrainMode,
+    mode::Mode,
     xs::AbstractMatrix{<:Real},
     ys::AbstractMatrix{<:Real},
     ps::Any,
diff --git a/src/core/types.jl b/src/core/types.jl
index bcef0720..6f61224a 100644
--- a/src/core/types.jl
+++ b/src/core/types.jl
@@ -1,6 +1,14 @@
-abstract type Mode end
-struct TestMode <: Mode end
-struct TrainMode <: Mode end
+abstract type Mode{REG} end
+struct TestMode{REG} <: Mode{REG} end
+struct TrainMode{REG} <: Mode{REG} end
+
+function TestMode()
+    return TestMode{false}()
+end
+
+function TrainMode()
+    return TrainMode{true}()
+end

 abstract type ComputeMode{ADBack} end
 abstract type VectorMode{ADBack} <: ComputeMode{ADBack} end
diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl
index 67b785e1..a453af17 100644
--- a/src/exts/mlj_ext/core_cond_icnf.jl
+++ b/src/exts/mlj_ext/core_cond_icnf.jl
@@ -34,7 +34,7 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
     data = model.m.device(data)
     optprob = SciMLBase.OptimizationProblem{true}(
         SciMLBase.OptimizationFunction{true}(
-            make_opt_loss(model.m, TrainMode(), st, model.loss),
+            make_opt_loss(model.m, TrainMode{true}(), st, model.loss),
             model.adtype,
         ),
         ps,
@@ -66,13 +66,13 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew)
         @warn "to compute by vectors, data should be a vector."
         broadcast(
             function (x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
-                return first(inference(model.m, TestMode(), x, y, ps, st))
+                return first(inference(model.m, TestMode{false}(), x, y, ps, st))
             end,
             collect(collect.(eachcol(xnew))),
             collect(collect.(eachcol(ynew))),
         )
     elseif model.m.compute_mode isa MatrixMode
-        first(inference(model.m, TestMode(), xnew, ynew, ps, st))
+        first(inference(model.m, TestMode{false}(), xnew, ynew, ps, st))
     else
         error("Not Implemented")
     end
diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl
index c5a8dbe3..d144840a 100644
--- a/src/exts/mlj_ext/core_icnf.jl
+++ b/src/exts/mlj_ext/core_icnf.jl
@@ -31,7 +31,7 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
     data = model.m.device(data)
     optprob = SciMLBase.OptimizationProblem{true}(
         SciMLBase.OptimizationFunction{true}(
-            make_opt_loss(model.m, TrainMode(), st, model.loss),
+            make_opt_loss(model.m, TrainMode{true}(), st, model.loss),
             model.adtype,
         ),
         ps,
@@ -60,12 +60,12 @@ function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew)
         @warn "to compute by vectors, data should be a vector."
         broadcast(
             function (x::AbstractVector{<:Real})
-                return first(inference(model.m, TestMode(), x, ps, st))
+                return first(inference(model.m, TestMode{false}(), x, ps, st))
             end,
             collect(collect.(eachcol(xnew))),
         )
     elseif model.m.compute_mode isa MatrixMode
-        first(inference(model.m, TestMode(), xnew, ps, st))
+        first(inference(model.m, TestMode{false}(), xnew, ps, st))
     else
         error("Not Implemented")
     end
diff --git a/test/ci_tests/regression_tests.jl b/test/ci_tests/regression_tests.jl
index f90d02a0..da21510e 100644
--- a/test/ci_tests/regression_tests.jl
+++ b/test/ci_tests/regression_tests.jl
@@ -41,7 +41,10 @@ Test.@testset "Regression Tests" begin
     mach = MLJBase.machine(model, df)
     MLJBase.fit!(mach)

-    d = ContinuousNormalizingFlows.ICNFDist(mach, ContinuousNormalizingFlows.TestMode())
+    d = ContinuousNormalizingFlows.ICNFDist(
+        mach,
+        ContinuousNormalizingFlows.TestMode{true}(),
+    )

     actual_pdf = Distributions.pdf.(data_dist, r)
     estimated_pdf = Distributions.pdf(d, r)
diff --git a/test/ci_tests/smoke_tests.jl b/test/ci_tests/smoke_tests.jl
index 13d71fad..2ed695cd 100644
--- a/test/ci_tests/smoke_tests.jl
+++ b/test/ci_tests/smoke_tests.jl
@@ -1,8 +1,8 @@
 Test.@testset "Smoke Tests" begin
     mts = Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.ICNF]
     omodes = ContinuousNormalizingFlows.Mode[
-        ContinuousNormalizingFlows.TrainMode(),
-        ContinuousNormalizingFlows.TestMode(),
+        ContinuousNormalizingFlows.TrainMode{true}(),
+        ContinuousNormalizingFlows.TestMode{true}(),
     ]
     conds, inplaces = if GROUP == "SmokeXOut"
         Bool[false], Bool[false]
diff --git a/test/ci_tests/speed_tests.jl b/test/ci_tests/speed_tests.jl
index 5b0780ba..94d28e90 100644
--- a/test/ci_tests/speed_tests.jl
+++ b/test/ci_tests/speed_tests.jl
@@ -71,7 +71,10 @@ Test.@testset "Speed Tests" begin

     @show only(MLJBase.report(mach).stats).time

-    d = ContinuousNormalizingFlows.ICNFDist(mach, ContinuousNormalizingFlows.TestMode())
+    d = ContinuousNormalizingFlows.ICNFDist(
+        mach,
+        ContinuousNormalizingFlows.TestMode{true}(),
+    )

     actual_pdf = Distributions.pdf.(data_dist, r)
     estimated_pdf = Distributions.pdf(d, r)
diff --git a/test/quality_tests/checkby_JET_tests.jl b/test/quality_tests/checkby_JET_tests.jl
index efc7240e..12a95bec 100644
--- a/test/quality_tests/checkby_JET_tests.jl
+++ b/test/quality_tests/checkby_JET_tests.jl
@@ -7,8 +7,8 @@
 Test.@testset "CheckByJET" begin
     mts = Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.ICNF]
     omodes = ContinuousNormalizingFlows.Mode[
-        ContinuousNormalizingFlows.TrainMode(),
-        ContinuousNormalizingFlows.TestMode(),
+        ContinuousNormalizingFlows.TrainMode{true}(),
+        ContinuousNormalizingFlows.TestMode{true}(),
     ]
     conds = Bool[false, true]
     inplaces = Bool[false, true]
diff --git a/test/runtests.jl b/test/runtests.jl
index 1c916e13..430284f9 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -24,7 +24,7 @@ import ADTypes,

 GROUP = get(ENV, "GROUP", "All")

-if (GROUP == "All")
+if GROUP == "All"
     GC.enable_logging(true)

     debuglogger = TerminalLoggers.TerminalLogger(stderr, Logging.Debug)
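
Note on the API change in src/core/types.jl above: `Mode` is now parameterized by a regularization flag, `Mode{REG}`, so the extra accumulators (`Ė`, `ṅ`, and the `Ȧ` norm) are gated by `REG` rather than by the train/test distinction alone, and `loss` now accepts any `Mode`. The zero-argument constructors keep backward-compatible defaults. A minimal usage sketch, assuming an already-constructed `icnf` with parameters `ps`, states `st`, and data `xs` (hypothetical names, not part of this patch):

    import ContinuousNormalizingFlows as CNF

    # The old zero-argument spellings still work via the new convenience constructors:
    CNF.TrainMode() === CNF.TrainMode{true}()   # training defaults to regularized
    CNF.TestMode() === CNF.TestMode{false}()    # testing defaults to unregularized

    # REG is now independent of train/test; e.g., an unregularized train-mode loss:
    # CNF.loss(icnf, CNF.TrainMode{false}(), xs, ps, st)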