Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Expand All @@ -31,6 +30,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand All @@ -49,7 +49,6 @@ Dates = "1"
DifferentiationInterface = "0.6"
Distributions = "0.25"
DistributionsAD = "0.6"
Enzyme = "0.13"
FillArrays = "1"
LinearAlgebra = "1"
Lux = "1"
Expand All @@ -67,4 +66,5 @@ SciMLBase = "2"
SciMLSensitivity = "7"
ScientificTypesBase = "3"
Statistics = "1"
Zygote = "0.6, 0.7"
julia = "1.10"
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ n_in = nvars + naugs # with augmentation
n = 1024

# Model
using ContinuousNormalizingFlows, Lux, ADTypes #, Enzyme, CUDA, ComputationalResources
using ContinuousNormalizingFlows, Lux, ADTypes #, Zygote, CUDA, ComputationalResources
nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh))
icnf = construct(
RNODE,
nn,
nvars, # number of variables
naugs; # number of augmented dimensions
# compute_mode = DIVecJacMatrixMode(AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation = Enzyme.Const)), # process data in batches and use Enzyme
# compute_mode = DIVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote
# inplace = true, # use the inplace version of functions
# resource = CUDALibs(), # process data by GPU
tspan = (0.0f0, 13.0f0), # have bigger time span
Expand All @@ -74,13 +74,13 @@ r = rand(data_dist, nvars, n)
r = convert.(Float32, r)

# Fit It
using DataFrames, MLJBase #, Enzyme, ADTypes, OptimizationOptimisers
using DataFrames, MLJBase #, Zygote, ADTypes, OptimizationOptimisers
df = DataFrame(transpose(r), :auto)
model = ICNFModel(
icnf;
# optimizers = (Lion(),),
# n_epochs = 300,
# adtype = AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation = Enzyme.Const),
# adtype = AutoZygote(),
# use_batch = true,
# batch_size = 32,
sol_kwargs = (; progress = true,), # pass to the solver
Expand Down
4 changes: 2 additions & 2 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
BenchmarkTools = "1"
ComponentArrays = "0.15"
DifferentiationInterface = "0.6"
Enzyme = "0.13"
Lux = "1"
PkgBenchmark = "0.2"
StableRNGs = "1"
Zygote = "0.6, 0.7"
julia = "1.10"
72 changes: 11 additions & 61 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ import ADTypes,
BenchmarkTools,
ComponentArrays,
DifferentiationInterface,
Enzyme,
Lux,
PkgBenchmark,
StableRNGs,
Zygote,
ContinuousNormalizingFlows

SUITE = BenchmarkTools.BenchmarkGroup()
Expand Down Expand Up @@ -33,12 +33,7 @@ icnf = ContinuousNormalizingFlows.construct(
nn,
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
Expand All @@ -57,22 +52,8 @@ end

diff_loss_tn(ps)
diff_loss_tt(ps)
DifferentiationInterface.gradient(
diff_loss_tn,
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
ps,
)
DifferentiationInterface.gradient(
diff_loss_tt,
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
ps,
)
DifferentiationInterface.gradient(diff_loss_tn, ADTypes.AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt, ADTypes.AutoZygote(), ps)
GC.gc()

SUITE["main"]["no_inplace"]["direct"]["train"] =
Expand All @@ -82,19 +63,13 @@ SUITE["main"]["no_inplace"]["direct"]["test"] =
SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tn,
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
ADTypes.AutoZygote(),
ps,
)
SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tt,
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
ADTypes.AutoZygote(),
ps,
)

Expand All @@ -104,12 +79,7 @@ icnf2 = ContinuousNormalizingFlows.construct(
nvars,
naugs;
inplace = true,
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
Expand All @@ -125,22 +95,8 @@ end

diff_loss_tn2(ps)
diff_loss_tt2(ps)
DifferentiationInterface.gradient(
diff_loss_tn2,
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
ps,
)
DifferentiationInterface.gradient(
diff_loss_tt2,
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
ps,
)
DifferentiationInterface.gradient(diff_loss_tn2, ADTypes.AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt2, ADTypes.AutoZygote(), ps)
GC.gc()

SUITE["main"]["inplace"]["direct"]["train"] =
Expand All @@ -149,18 +105,12 @@ SUITE["main"]["inplace"]["direct"]["test"] = BenchmarkTools.@benchmarkable diff_
SUITE["main"]["inplace"]["AD-1-order"]["train"] =
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tn2,
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
ADTypes.AutoZygote(),
ps,
)
SUITE["main"]["inplace"]["AD-1-order"]["test"] =
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tt2,
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
ADTypes.AutoZygote(),
ps,
)
4 changes: 2 additions & 2 deletions src/ContinuousNormalizingFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import ADTypes,
DifferentiationInterface,
Distributions,
DistributionsAD,
Enzyme,
FillArrays,
LinearAlgebra,
Lux,
Expand All @@ -27,7 +26,8 @@ import ADTypes,
ScientificTypesBase,
SciMLBase,
SciMLSensitivity,
Statistics
Statistics,
Zygote

export construct,
inference,
Expand Down
7 changes: 1 addition & 6 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@ function construct(
nvars::Int,
naugmented::Int = 0;
data_type::Type{<:AbstractFloat} = Float32,
compute_mode::ComputeMode = DIVecJacMatrixMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
compute_mode::ComputeMode = DIVecJacMatrixMode(ADTypes.AutoZygote()),
inplace::Bool = false,
cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar},
resource::ComputationalResources.AbstractResource = ComputationalResources.CPU1(),
Expand Down
5 changes: 1 addition & 4 deletions src/exts/mlj_ext/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ function CondICNFModel(
loss::Function = loss;
optimizers::Tuple = (Optimisers.Lion(),),
n_epochs::Int = 300,
adtype::ADTypes.AbstractADType = ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(),
use_batch::Bool = true,
batch_size::Int = 32,
sol_kwargs::NamedTuple = (;),
Expand Down
5 changes: 1 addition & 4 deletions src/exts/mlj_ext/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ function ICNFModel(
loss::Function = loss;
optimizers::Tuple = (Optimisers.Lion(),),
n_epochs::Int = 300,
adtype::ADTypes.AbstractADType = ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(),
use_batch::Bool = true,
batch_size::Int = 32,
sol_kwargs::NamedTuple = (;),
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
y = f(xs)
z = similar(xs)
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
res = similar(xs, size(xs, 1), size(xs, 1), size(xs, 2))
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
for i in axes(xs, 1)
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
res[i, :, :] =
only(DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (z,)))
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
end
y, eachslice(res; dims = 3)
y, eachslice(copy(res); dims = 3)
end

@inline function jacobian_batched(
Expand All @@ -24,15 +24,15 @@ end
y = f(xs)
z = similar(xs)
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
res = similar(xs, size(xs, 1), size(xs, 1), size(xs, 2))
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
for i in axes(xs, 1)
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
res[:, i, :] = only(
DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (z,)),
)
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
end
y, eachslice(res; dims = 3)
y, eachslice(copy(res); dims = 3)
end

@inline function jacobian_batched(
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
Expand All @@ -33,4 +34,5 @@ MLJBase = "1"
SciMLBase = "2"
StableRNGs = "1"
TerminalLoggers = "0.1"
Zygote = "0.6, 0.7"
julia = "1.10"
22 changes: 13 additions & 9 deletions test/call_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,21 @@ Test.@testset "Call Tests" begin
nvars_ = Int[2]
aug_steers = Bool[false, true]
inplaces = Bool[false, true]
adtypes = ADTypes.AbstractADType[
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
adtypes = ADTypes.AbstractADType[ADTypes.AutoZygote(),
# ADTypes.AutoEnzyme(;
# mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
# function_annotation = Enzyme.Const,
# ),
# ADTypes.AutoEnzyme(;
# mode = Enzyme.set_runtime_activity(Enzyme.Forward),
# function_annotation = Enzyme.Const,
# ),
]
compute_modes = ContinuousNormalizingFlows.ComputeMode[
ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacVectorMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
Expand Down
22 changes: 13 additions & 9 deletions test/fit_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,21 @@ Test.@testset "Fit Tests" begin
nvars_ = Int[2]
aug_steers = Bool[false, true]
inplaces = Bool[false, true]
adtypes = ADTypes.AbstractADType[
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
adtypes = ADTypes.AbstractADType[ADTypes.AutoZygote(),
# ADTypes.AutoEnzyme(;
# mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
# function_annotation = Enzyme.Const,
# ),
# ADTypes.AutoEnzyme(;
# mode = Enzyme.set_runtime_activity(Enzyme.Forward),
# function_annotation = Enzyme.Const,
# ),
]
compute_modes = ContinuousNormalizingFlows.ComputeMode[
ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacVectorMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
Expand Down
7 changes: 1 addition & 6 deletions test/instability_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@ Test.@testset "Instability" begin
nn,
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(;
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
Expand Down
Loading
Loading