diff --git a/.github/workflows/CI-CheckBy.yml b/.github/workflows/CI-CheckBy.yml
index 509e217b..4944b04e 100644
--- a/.github/workflows/CI-CheckBy.yml
+++ b/.github/workflows/CI-CheckBy.yml
@@ -26,7 +26,7 @@ jobs:
         version:
           - release
           - lts
-          - nightly
+          # - nightly
         os:
           - ubuntu-latest
           # - macOS-latest
diff --git a/src/icnf.jl b/src/icnf.jl
index 404c22ee..43cf13fd 100644
--- a/src/icnf.jl
+++ b/src/icnf.jl
@@ -120,7 +120,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.compute_mode.adback, z)
+    ż, J = icnf_jacobian(icnf, mode, snn, z)
     l̇ = -LinearAlgebra.tr(J)
     return vcat(ż, l̇)
 end
@@ -139,7 +139,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.compute_mode.adback, z)
+    ż, J = icnf_jacobian(icnf, mode, snn, z)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.tr(J)
     return nothing
 end
@@ -158,8 +158,8 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, J = jacobian_batched(icnf, snn, z)
-    l̇ = -transpose(LinearAlgebra.tr.(J))
+    ż, J = icnf_jacobian(icnf, mode, snn, z)
+    l̇ = -transpose(LinearAlgebra.tr.(eachslice(J; dims = 3)))
     return vcat(ż, l̇)
 end
 
@@ -177,9 +177,9 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, J = jacobian_batched(icnf, snn, z)
+    ż, J = icnf_jacobian(icnf, mode, snn, z)
     du[begin:(end - n_aug - 1), :] .= ż
-    du[(end - n_aug), :] .= -(LinearAlgebra.tr.(J))
+    du[(end - n_aug), :] .= -(LinearAlgebra.tr.(eachslice(J; dims = 3)))
     return nothing
 end
 
@@ -196,9 +196,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, ϵJ =
-        DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
-    ϵJ = only(ϵJ)
+    ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     l̇ = -LinearAlgebra.dot(ϵJ, ϵ)
     Ė = if NORM_Z
         LinearAlgebra.norm(ż)
@@ -227,9 +225,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, ϵJ =
-        DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
-    ϵJ = only(ϵJ)
+    ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.dot(ϵJ, ϵ)
     du[(end - n_aug + 1)] = if NORM_Z
@@ -258,13 +254,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, Jϵ = DifferentiationInterface.value_and_pushforward(
-        snn,
-        icnf.compute_mode.adback,
-        z,
-        (ϵ,),
-    )
-    Jϵ = only(Jϵ)
+    ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     l̇ = -LinearAlgebra.dot(ϵ, Jϵ)
     Ė = if NORM_Z
         LinearAlgebra.norm(ż)
@@ -293,13 +283,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1)]
-    ż, Jϵ = DifferentiationInterface.value_and_pushforward(
-        snn,
-        icnf.compute_mode.adback,
-        z,
-        (ϵ,),
-    )
-    Jϵ = only(Jϵ)
+    ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     du[begin:(end - n_aug - 1)] .= ż
     du[(end - n_aug)] = -LinearAlgebra.dot(ϵ, Jϵ)
     du[(end - n_aug + 1)] = if NORM_Z
@@ -328,9 +312,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, ϵJ =
-        DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
-    ϵJ = only(ϵJ)
+    ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     l̇ = -sum(ϵJ .* ϵ; dims = 1)
     Ė = transpose(if NORM_Z
         LinearAlgebra.norm.(eachcol(ż))
@@ -363,9 +345,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, ϵJ =
-        DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
-    ϵJ = only(ϵJ)
+    ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1))
     du[(end - n_aug + 1), :] .= if NORM_Z
@@ -394,13 +374,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, Jϵ = DifferentiationInterface.value_and_pushforward(
-        snn,
-        icnf.compute_mode.adback,
-        z,
-        (ϵ,),
-    )
-    Jϵ = only(Jϵ)
+    ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     l̇ = -sum(ϵ .* Jϵ; dims = 1)
     Ė = transpose(if NORM_Z
         LinearAlgebra.norm.(eachcol(ż))
@@ -433,13 +407,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż, Jϵ = DifferentiationInterface.value_and_pushforward(
-        snn,
-        icnf.compute_mode.adback,
-        z,
-        (ϵ,),
-    )
-    Jϵ = only(Jϵ)
+    ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1))
     du[(end - n_aug + 1), :] .= if NORM_Z
@@ -468,8 +436,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż = snn(z)
-    ϵJ = Lux.vector_jacobian_product(snn, icnf.compute_mode.adback, z, ϵ)
+    ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     l̇ = -sum(ϵJ .* ϵ; dims = 1)
     Ė = transpose(if NORM_Z
         LinearAlgebra.norm.(eachcol(ż))
@@ -502,8 +469,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż = snn(z)
-    ϵJ = Lux.vector_jacobian_product(snn, icnf.compute_mode.adback, z, ϵ)
+    ż, ϵJ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1))
     du[(end - n_aug + 1), :] .= if NORM_Z
@@ -532,8 +498,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż = snn(z)
-    Jϵ = Lux.jacobian_vector_product(snn, icnf.compute_mode.adback, z, ϵ)
+    ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     l̇ = -sum(ϵ .* Jϵ; dims = 1)
     Ė = transpose(if NORM_Z
         LinearAlgebra.norm.(eachcol(ż))
@@ -566,8 +531,7 @@ function augmented_f(
     n_aug = n_augment(icnf, mode)
     snn = LuxCore.StatefulLuxLayer{true}(nn, p, st)
     z = u[begin:(end - n_aug - 1), :]
-    ż = snn(z)
-    Jϵ = Lux.jacobian_vector_product(snn, icnf.compute_mode.adback, z, ϵ)
+    ż, Jϵ = icnf_jacobian(icnf, mode, snn, z, ϵ)
     du[begin:(end - n_aug - 1), :] .= ż
     du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1))
     du[(end - n_aug + 1), :] .= if NORM_Z
diff --git a/src/utils.jl b/src/utils.jl
index df9780dc..c0ce980e 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,40 +1,66 @@
-function jacobian_batched(
+function icnf_jacobian(
+    icnf::AbstractICNF{<:AbstractFloat, <:DIVectorMode},
+    ::TestMode,
+    f::LuxCore.StatefulLuxLayer,
+    xs::AbstractVector{<:Real},
+)
+    y = f(xs)
+    return y,
+    oftype(hcat(y), DifferentiationInterface.jacobian(f, icnf.compute_mode.adback, xs))
+end
+
+function icnf_jacobian(
+    icnf::AbstractICNF{<:AbstractFloat, <:DIMatrixMode},
+    ::TestMode,
+    f::LuxCore.StatefulLuxLayer,
+    xs::AbstractMatrix{<:Real},
+)
+    y = f(xs)
+    J = DifferentiationInterface.jacobian(f, icnf.compute_mode.adback, xs)
+    return y,
+    oftype(
+        cat(y; dims = Val(3)),
+        cat(
+            (
+                J[i:j, i:j] for (i, j) in zip(
+                    firstindex(J, 1):size(y, 1):lastindex(J, 1),
+                    (firstindex(J, 1) + size(y, 1) - 1):size(y, 1):lastindex(J, 1),
+                )
+            )...;
+            dims = Val(3),
+        ),
+    )
+end
+
+function icnf_jacobian(
     icnf::AbstractICNF{T, <:DIVecJacMatrixMode},
+    ::TestMode,
     f::LuxCore.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
-) where {T}
+) where {T <: AbstractFloat}
     y = f(xs)
     z = similar(xs)
     ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
-    res = Zygote.Buffer(
-        convert.(promote_type(eltype(xs), eltype(f.ps)), xs),
-        size(xs, 1),
-        size(xs, 1),
-        size(xs, 2),
-    )
+    res = Zygote.Buffer(y, size(xs, 1), size(xs, 1), size(xs, 2))
     for i in axes(xs, 1)
         ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
         res[i, :, :] =
             only(DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (z,)))
         ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
     end
-    return y, eachslice(copy(res); dims = 3)
+    return y, oftype(cat(y; dims = Val(3)), copy(res))
 end
 
-function jacobian_batched(
+function icnf_jacobian(
     icnf::AbstractICNF{T, <:DIJacVecMatrixMode},
+    ::TestMode,
     f::LuxCore.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
-) where {T}
+) where {T <: AbstractFloat}
     y = f(xs)
     z = similar(xs)
     ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
-    res = Zygote.Buffer(
-        convert.(promote_type(eltype(xs), eltype(f.ps)), xs),
-        size(xs, 1),
-        size(xs, 1),
-        size(xs, 2),
-    )
+    res = Zygote.Buffer(y, size(xs, 1), size(xs, 1), size(xs, 2))
     for i in axes(xs, 1)
         ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
         res[:, i, :] = only(
@@ -42,33 +68,98 @@ function jacobian_batched(
         )
         ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
     end
-    return y, eachslice(copy(res); dims = 3)
+    return y, oftype(cat(y; dims = Val(3)), copy(res))
 end
 
-function jacobian_batched(
-    icnf::AbstractICNF{T, <:DIMatrixMode},
+function icnf_jacobian(
+    icnf::AbstractICNF{<:AbstractFloat, <:LuxMatrixMode},
+    ::TestMode,
     f::LuxCore.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
-) where {T}
-    y, J = DifferentiationInterface.value_and_jacobian(f, icnf.compute_mode.adback, xs)
-    return y, split_jac(J, size(xs, 1))
+)
+    y = f(xs)
+    return y,
+    oftype(cat(y; dims = Val(3)), Lux.batched_jacobian(f, icnf.compute_mode.adback, xs))
 end
 
-function split_jac(x::AbstractMatrix{<:Real}, sz::Integer)
-    return (
-        x[i:j, i:j] for (i, j) in zip(
-            firstindex(x, 1):sz:lastindex(x, 1),
-            (firstindex(x, 1) + sz - 1):sz:lastindex(x, 1),
-        )
+function icnf_jacobian(
+    icnf::AbstractICNF{T, <:DIVecJacVectorMode},
+    ::TrainMode,
+    f::LuxCore.StatefulLuxLayer,
+    xs::AbstractVector{<:Real},
+    ϵ::AbstractVector{T},
+) where {T <: AbstractFloat}
+    y = f(xs)
+    return y,
+    oftype(
+        y,
+        only(DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (ϵ,))),
+    )
+end
+
+function icnf_jacobian(
+    icnf::AbstractICNF{T, <:DIJacVecVectorMode},
+    ::TrainMode,
+    f::LuxCore.StatefulLuxLayer,
+    xs::AbstractVector{<:Real},
+    ϵ::AbstractVector{T},
+) where {T <: AbstractFloat}
+    y = f(xs)
+    return y,
+    oftype(
+        y,
+        only(DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (ϵ,))),
+    )
+end
+
+function icnf_jacobian(
+    icnf::AbstractICNF{T, <:DIVecJacMatrixMode},
+    ::TrainMode,
+    f::LuxCore.StatefulLuxLayer,
+    xs::AbstractMatrix{<:Real},
+    ϵ::AbstractMatrix{T},
+) where {T <: AbstractFloat}
+    y = f(xs)
+    return y,
+    oftype(
+        y,
+        only(DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (ϵ,))),
+    )
+end
+
+function icnf_jacobian(
+    icnf::AbstractICNF{T, <:DIJacVecMatrixMode},
+    ::TrainMode,
+    f::LuxCore.StatefulLuxLayer,
+    xs::AbstractMatrix{<:Real},
+    ϵ::AbstractMatrix{T},
+) where {T <: AbstractFloat}
+    y = f(xs)
+    return y,
+    oftype(
+        y,
+        only(DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (ϵ,))),
     )
 end
 
-function jacobian_batched(
-    icnf::AbstractICNF{T, <:LuxMatrixMode},
+function icnf_jacobian(
+    icnf::AbstractICNF{T, <:LuxVecJacMatrixMode},
+    ::TrainMode,
+    f::LuxCore.StatefulLuxLayer,
+    xs::AbstractMatrix{<:Real},
+    ϵ::AbstractMatrix{T},
+) where {T <: AbstractFloat}
+    y = f(xs)
+    return y, oftype(y, Lux.vector_jacobian_product(f, icnf.compute_mode.adback, xs, ϵ))
+end
+
+function icnf_jacobian(
+    icnf::AbstractICNF{T, <:LuxJacVecMatrixMode},
+    ::TrainMode,
     f::LuxCore.StatefulLuxLayer,
     xs::AbstractMatrix{<:Real},
-) where {T}
+    ϵ::AbstractMatrix{T},
+) where {T <: AbstractFloat}
     y = f(xs)
-    J = Lux.batched_jacobian(f, icnf.compute_mode.adback, xs)
-    return y, eachslice(J; dims = 3)
+    return y, oftype(y, Lux.jacobian_vector_product(f, icnf.compute_mode.adback, xs, ϵ))
 end
diff --git a/test/checkby_JET_tests.jl b/test/checkby_JET_tests.jl
index 2e57b511..6119889e 100644
--- a/test/checkby_JET_tests.jl
+++ b/test/checkby_JET_tests.jl
@@ -5,43 +5,143 @@ Test.@testset "CheckByJET" begin
         mode = :typo,
     )
 
-    nvars = 2^3
-    naugs = nvars
-    n_in = nvars + naugs
-    n = 2^6
-    nn = Lux.Chain(Lux.Dense(n_in => n_in, tanh))
-
-    icnf = ContinuousNormalizingFlows.construct(
-        ContinuousNormalizingFlows.ICNF,
-        nn,
-        nvars,
-        naugs;
-        compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
-        tspan = (0.0f0, 1.0f0),
-        steer_rate = 1.0f-1,
-        λ₁ = 1.0f-2,
-        λ₂ = 1.0f-2,
-        λ₃ = 1.0f-2,
-        sol_kwargs = (;
-            save_everystep = false,
-            alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
-            sensealg = SciMLSensitivity.InterpolatingAdjoint(),
+    mts = Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.ICNF]
+    omodes = ContinuousNormalizingFlows.Mode[
+        ContinuousNormalizingFlows.TrainMode(),
+        ContinuousNormalizingFlows.TestMode(),
+    ]
+    conds = Bool[false, true]
+    inplaces = Bool[false, true]
+    planars = Bool[false, true]
+    nvars_ = Int[2]
+    ndata_ = Int[4]
+    data_types = Type{<:AbstractFloat}[Float32]
+    devices = MLDataDevices.AbstractDevice[MLDataDevices.cpu_device()]
+    compute_modes = ContinuousNormalizingFlows.ComputeMode[
+        ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
+        ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
+        ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
+        ContinuousNormalizingFlows.LuxJacVecMatrixMode(ADTypes.AutoForwardDiff()),
+        ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoForwardDiff()),
+        ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoForwardDiff()),
+        ContinuousNormalizingFlows.DIVecJacVectorMode(
+            ADTypes.AutoEnzyme(;
+                mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
+                function_annotation = Enzyme.Const,
+            ),
         ),
-    )
-    ps, st = LuxCore.setup(icnf.rng, icnf)
-    ps = ComponentArrays.ComponentArray(ps)
-    r = rand(icnf.rng, Float32, nvars, n)
+        ContinuousNormalizingFlows.DIVecJacMatrixMode(
+            ADTypes.AutoEnzyme(;
+                mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
+                function_annotation = Enzyme.Const,
+            ),
+        ),
+        ContinuousNormalizingFlows.DIJacVecVectorMode(
+            ADTypes.AutoEnzyme(;
+                mode = Enzyme.set_runtime_activity(Enzyme.Forward),
+                function_annotation = Enzyme.Const,
+            ),
+        ),
+        ContinuousNormalizingFlows.DIJacVecMatrixMode(
+            ADTypes.AutoEnzyme(;
+                mode = Enzyme.set_runtime_activity(Enzyme.Forward),
+                function_annotation = Enzyme.Const,
+            ),
+        ),
+    ]
 
-    ContinuousNormalizingFlows.loss(icnf, ContinuousNormalizingFlows.TrainMode(), r, ps, st)
-    JET.test_call(
-        ContinuousNormalizingFlows.loss,
-        Base.typesof(icnf, ContinuousNormalizingFlows.TrainMode(), r, ps, st);
-        target_modules = [ContinuousNormalizingFlows],
-        mode = :typo,
-    )
-    JET.test_opt(
-        ContinuousNormalizingFlows.loss,
-        Base.typesof(icnf, ContinuousNormalizingFlows.TrainMode(), r, ps, st);
-        target_modules = [ContinuousNormalizingFlows],
-    )
+    Test.@testset "$device | $data_type | $compute_mode | ndata = $ndata | nvars = $nvars | inplace = $inplace | cond = $cond | planar = $planar | $omode | $mt" for device in
+                                                                                                                                                                     devices,
+        data_type in data_types,
+        compute_mode in compute_modes,
+        ndata in ndata_,
+        nvars in nvars_,
+        inplace in inplaces,
+        cond in conds,
+        planar in planars,
+        omode in omodes,
+        mt in mts
+
+        data_dist =
+            Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (2, 4))...)
+        data_dist2 =
+            Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (4, 2))...)
+        if compute_mode isa ContinuousNormalizingFlows.VectorMode
+            r = convert.(data_type, rand(data_dist, nvars))
+            r2 = convert.(data_type, rand(data_dist2, nvars))
+        elseif compute_mode isa ContinuousNormalizingFlows.MatrixMode
+            r = convert.(data_type, rand(data_dist, nvars, ndata))
+            r2 = convert.(data_type, rand(data_dist2, nvars, ndata))
+        end
+
+        nn = ifelse(
+            cond,
+            ifelse(
+                planar,
+                Lux.Chain(
+                    ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh; n_cond = nvars),
+                ),
+                Lux.Chain(Lux.Dense(nvars * 3 => nvars * 2, tanh)),
+            ),
+            ifelse(
+                planar,
+                Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh)),
+                Lux.Chain(Lux.Dense(nvars * 2 => nvars * 2, tanh)),
+            ),
+        )
+        icnf = ContinuousNormalizingFlows.construct(
+            mt,
+            nn,
+            nvars,
+            nvars;
+            data_type,
+            compute_mode,
+            inplace,
+            cond,
+            device,
+            steer_rate = convert(data_type, 1.0e-1),
+            λ₁ = convert(data_type, 1.0e-2),
+            λ₂ = convert(data_type, 1.0e-2),
+            λ₃ = convert(data_type, 1.0e-2),
+            sol_kwargs = (;
+                save_everystep = false,
+                alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
+                sensealg = SciMLSensitivity.InterpolatingAdjoint(),
+            ),
+        )
+        ps, st = LuxCore.setup(icnf.rng, icnf)
+        ps = ComponentArrays.ComponentArray(ps)
+        r = device(r)
+        r2 = device(r2)
+        ps = device(ps)
+        st = device(st)
+
+        if cond
+            ContinuousNormalizingFlows.loss(icnf, omode, r, r2, ps, st)
+            JET.test_call(
+                ContinuousNormalizingFlows.loss,
+                Base.typesof(icnf, omode, r, r2, ps, st);
+                target_modules = [ContinuousNormalizingFlows],
+                mode = :typo,
+            )
+            JET.test_opt(
+                ContinuousNormalizingFlows.loss,
+                Base.typesof(icnf, omode, r, r2, ps, st);
+                target_modules = [ContinuousNormalizingFlows],
+            )
+        else
+            ContinuousNormalizingFlows.loss(icnf, omode, r, ps, st)
+            JET.test_call(
+                ContinuousNormalizingFlows.loss,
+                Base.typesof(icnf, omode, r, ps, st);
+                target_modules = [ContinuousNormalizingFlows],
+                mode = :typo,
+            )
+            JET.test_opt(
+                ContinuousNormalizingFlows.loss,
+                Base.typesof(icnf, omode, r, ps, st);
+                target_modules = [ContinuousNormalizingFlows],
+            )
+        end
+    end
 end