Skip to content

Commit 84c59cb

Browse files
authored
add Lux compute types (#462)
* add Lux compute types * add tests * use lux cm * fix J
1 parent f19a24c commit 84c59cb

File tree

9 files changed

+154
-4
lines changed

9 files changed

+154
-4
lines changed

benchmark/benchmarks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ icnf = ContinuousNormalizingFlows.construct(
3333
nn,
3434
nvars,
3535
naugs;
36-
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
36+
compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
3737
tspan = (0.0f0, 13.0f0),
3838
steer_rate = 1.0f-1,
3939
λ₃ = 1.0f-2,
@@ -79,7 +79,7 @@ icnf2 = ContinuousNormalizingFlows.construct(
7979
nvars,
8080
naugs;
8181
inplace = true,
82-
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
82+
compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
8383
tspan = (0.0f0, 13.0f0),
8484
steer_rate = 1.0f-1,
8585
λ₃ = 1.0f-2,

src/ContinuousNormalizingFlows.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ export construct,
4646
DIJacVecVectorMode,
4747
DIVecJacMatrixMode,
4848
DIJacVecMatrixMode,
49+
LuxVecJacMatrixMode,
50+
LuxJacVecMatrixMode,
4951
ICNFModel,
5052
CondICNFModel,
5153
CondLayer,

src/icnf.jl

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,134 @@ function augmented_f(
455455
nothing
456456
end
457457

458+
function augmented_f(
459+
u::Any,
460+
p::Any,
461+
::Any,
462+
icnf::ICNF{T, <:LuxVecJacMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
463+
mode::TrainMode,
464+
nn::LuxCore.AbstractLuxLayer,
465+
st::NamedTuple,
466+
ϵ::AbstractMatrix{T},
467+
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
468+
n_aug = n_augment(icnf, mode)
469+
snn = Lux.StatefulLuxLayer{true}(nn, p, st)
470+
z = u[begin:(end - n_aug - 1), :]
471+
= snn(z)
472+
ϵJ = Lux.vector_jacobian_product(snn, icnf.compute_mode.adback, z, ϵ)
473+
= -sum(ϵJ .* ϵ; dims = 1)
474+
= transpose(if NORM_Z
475+
LinearAlgebra.norm.(eachcol(ż))
476+
else
477+
zrs_Ė = similar(ż, size(ż, 2))
478+
ChainRulesCore.@ignore_derivatives fill!(zrs_Ė, zero(T))
479+
zrs_Ė
480+
end)
481+
= transpose(if NORM_J
482+
LinearAlgebra.norm.(eachcol(ϵJ))
483+
else
484+
zrs_ṅ = similar(ż, size(ż, 2))
485+
ChainRulesCore.@ignore_derivatives fill!(zrs_ṅ, zero(T))
486+
zrs_ṅ
487+
end)
488+
vcat(ż, l̇, Ė, ṅ)
489+
end
490+
491+
function augmented_f(
492+
du::Any,
493+
u::Any,
494+
p::Any,
495+
::Any,
496+
icnf::ICNF{T, <:LuxVecJacMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
497+
mode::TrainMode,
498+
nn::LuxCore.AbstractLuxLayer,
499+
st::NamedTuple,
500+
ϵ::AbstractMatrix{T},
501+
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
502+
n_aug = n_augment(icnf, mode)
503+
snn = Lux.StatefulLuxLayer{true}(nn, p, st)
504+
z = u[begin:(end - n_aug - 1), :]
505+
= snn(z)
506+
ϵJ = Lux.vector_jacobian_product(snn, icnf.compute_mode.adback, z, ϵ)
507+
du[begin:(end - n_aug - 1), :] .=
508+
du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1))
509+
du[(end - n_aug + 1), :] .= if NORM_Z
510+
LinearAlgebra.norm.(eachcol(ż))
511+
else
512+
zero(T)
513+
end
514+
du[(end - n_aug + 2), :] .= if NORM_J
515+
LinearAlgebra.norm.(eachcol(ϵJ))
516+
else
517+
zero(T)
518+
end
519+
nothing
520+
end
521+
522+
function augmented_f(
523+
u::Any,
524+
p::Any,
525+
::Any,
526+
icnf::ICNF{T, <:LuxJacVecMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
527+
mode::TrainMode,
528+
nn::LuxCore.AbstractLuxLayer,
529+
st::NamedTuple,
530+
ϵ::AbstractMatrix{T},
531+
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
532+
n_aug = n_augment(icnf, mode)
533+
snn = Lux.StatefulLuxLayer{true}(nn, p, st)
534+
z = u[begin:(end - n_aug - 1), :]
535+
= snn(z)
536+
= Lux.jacobian_vector_product(snn, icnf.compute_mode.adback, z, ϵ)
537+
= -sum(ϵ .* Jϵ; dims = 1)
538+
= transpose(if NORM_Z
539+
LinearAlgebra.norm.(eachcol(ż))
540+
else
541+
zrs_Ė = similar(ż, size(ż, 2))
542+
ChainRulesCore.@ignore_derivatives fill!(zrs_Ė, zero(T))
543+
zrs_Ė
544+
end)
545+
= transpose(if NORM_J
546+
LinearAlgebra.norm.(eachcol(Jϵ))
547+
else
548+
zrs_ṅ = similar(ż, size(ż, 2))
549+
ChainRulesCore.@ignore_derivatives fill!(zrs_ṅ, zero(T))
550+
zrs_ṅ
551+
end)
552+
vcat(ż, l̇, Ė, ṅ)
553+
end
554+
555+
function augmented_f(
556+
du::Any,
557+
u::Any,
558+
p::Any,
559+
::Any,
560+
icnf::ICNF{T, <:LuxJacVecMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J},
561+
mode::TrainMode,
562+
nn::LuxCore.AbstractLuxLayer,
563+
st::NamedTuple,
564+
ϵ::AbstractMatrix{T},
565+
) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J}
566+
n_aug = n_augment(icnf, mode)
567+
snn = Lux.StatefulLuxLayer{true}(nn, p, st)
568+
z = u[begin:(end - n_aug - 1), :]
569+
= snn(z)
570+
= Lux.jacobian_vector_product(snn, icnf.compute_mode.adback, z, ϵ)
571+
du[begin:(end - n_aug - 1), :] .=
572+
du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1))
573+
du[(end - n_aug + 1), :] .= if NORM_Z
574+
LinearAlgebra.norm.(eachcol(ż))
575+
else
576+
zero(T)
577+
end
578+
du[(end - n_aug + 2), :] .= if NORM_J
579+
LinearAlgebra.norm.(eachcol(Jϵ))
580+
else
581+
zero(T)
582+
end
583+
nothing
584+
end
585+
458586
@inline function loss(
459587
icnf::ICNF{<:AbstractFloat, <:VectorMode},
460588
mode::TrainMode,

src/types.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ struct DIJacVecMatrixMode{ADBack <: ADTypes.AbstractADType} <: DIMatrixMode{ADBa
2222
adback::ADBack
2323
end
2424

25+
abstract type LuxMatrixMode{ADBack} <: MatrixMode{ADBack} end
26+
struct LuxVecJacMatrixMode{ADBack <: ADTypes.AbstractADType} <: LuxMatrixMode{ADBack}
27+
adback::ADBack
28+
end
29+
struct LuxJacVecMatrixMode{ADBack <: ADTypes.AbstractADType} <: LuxMatrixMode{ADBack}
30+
adback::ADBack
31+
end
32+
2533
abstract type AbstractICNF{
2634
T <: AbstractFloat,
2735
CM <: ComputeMode,

src/utils.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,13 @@ end
5252
)
5353
)
5454
end
55+
56+
@inline function jacobian_batched(
57+
icnf::AbstractICNF{T, <:LuxMatrixMode},
58+
f::Lux.StatefulLuxLayer,
59+
xs::AbstractMatrix{<:Real},
60+
) where {T}
61+
y = f(xs)
62+
J = Lux.batched_jacobian(f, icnf.compute_mode.adback, xs)
63+
y, eachslice(J; dims = 3)
64+
end

test/call_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Test.@testset "Call Tests" begin
4040
# ),
4141
]
4242
compute_modes = ContinuousNormalizingFlows.ComputeMode[
43+
ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
4344
ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
4445
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()),
4546
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),

test/fit_tests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Test.@testset "Fit Tests" begin
3737
# ),
3838
]
3939
compute_modes = ContinuousNormalizingFlows.ComputeMode[
40+
ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
4041
ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
4142
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()),
4243
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),

test/instability_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Test.@testset "Instability" begin
1616
nn,
1717
nvars,
1818
naugs;
19-
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
19+
compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
2020
tspan = (0.0f0, 13.0f0),
2121
steer_rate = 1.0f-1,
2222
λ₃ = 1.0f-2,

test/regression_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Test.@testset "Regression Tests" begin
1111
nn,
1212
nvars,
1313
naugs;
14-
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
14+
compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
1515
tspan = (0.0f0, 13.0f0),
1616
steer_rate = 1.0f-1,
1717
λ₃ = 1.0f-2,

0 commit comments

Comments
 (0)