2 changes: 2 additions & 0 deletions Project.toml
@@ -29,6 +29,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
@@ -57,5 +58,6 @@ SciMLBase = "2"
SciMLSensitivity = "7"
ScientificTypesBase = "3"
Statistics = "1"
WeightInitializers = "1"
Zygote = "0.7"
julia = "1.10"
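Note: a minimal sketch of how a WeightInitializers initializer plugs into a Lux layer. This diff only adds the dependency, so the layer size and the glorot_uniform choice below are illustrative, not taken from the package:

import Lux, LuxCore, Random, WeightInitializers

rng = Random.Xoshiro(0)
# Pass an initializer from WeightInitializers to a Lux layer via `init_weight`.
layer = Lux.Dense(8 => 8, tanh; init_weight = WeightInitializers.glorot_uniform)
ps, st = LuxCore.setup(rng, layer)  # weights are drawn with glorot_uniform(rng, dims...)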
2 changes: 2 additions & 0 deletions benchmark/Project.toml
@@ -3,6 +3,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7"
@@ -16,6 +17,7 @@ ADTypes = "1"
BenchmarkTools = "1"
ComponentArrays = "0.15"
DifferentiationInterface = "0.7"
Distributions = "0.25"
Lux = "1"
LuxCore = "1"
OrdinaryDiffEqDefault = "1"
126 changes: 66 additions & 60 deletions benchmark/benchmarks.jl
@@ -2,6 +2,7 @@ import ADTypes,
BenchmarkTools,
ComponentArrays,
DifferentiationInterface,
Distributions,
Lux,
LuxCore,
OrdinaryDiffEqDefault,
@@ -11,25 +12,18 @@ import ADTypes,
Zygote,
ContinuousNormalizingFlows

SUITE = BenchmarkTools.BenchmarkGroup()

SUITE["main"] = BenchmarkTools.BenchmarkGroup(["package", "simple"])

SUITE["main"]["no_inplace"] = BenchmarkTools.BenchmarkGroup(["no_inplace"])
SUITE["main"]["inplace"] = BenchmarkTools.BenchmarkGroup(["inplace"])

SUITE["main"]["no_inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
SUITE["main"]["no_inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])

SUITE["main"]["inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
SUITE["main"]["inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])

rng = StableRNGs.StableRNG(1)
nvars = 2^3
ndata = 2^10
ndimension = 1
data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0)
r = rand(rng, data_dist, ndimension, ndata)
r = convert.(Float32, r)

nvars = size(r, 1)
naugs = nvars
n_in = nvars + naugs
n = 2^6
nn = Lux.Chain(Lux.Dense(n_in => n_in, tanh))

nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh))

icnf = ContinuousNormalizingFlows.construct(
ContinuousNormalizingFlows.ICNF,
@@ -49,9 +43,29 @@ icnf = ContinuousNormalizingFlows.construct(
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
),
)

icnf2 = ContinuousNormalizingFlows.construct(
ContinuousNormalizingFlows.ICNF,
nn,
nvars,
naugs;
inplace = true,
compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
tspan = (0.0f0, 1.0f0),
steer_rate = 1.0f-1,
λ₁ = 1.0f-2,
λ₂ = 1.0f-2,
λ₃ = 1.0f-2,
rng,
sol_kwargs = (;
save_everystep = false,
alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
),
)

ps, st = LuxCore.setup(icnf.rng, icnf)
ps = ComponentArrays.ComponentArray(ps)
r = rand(icnf.rng, Float32, nvars, n)

function diff_loss_tn(x::Any)
return ContinuousNormalizingFlows.loss(
@@ -72,49 +86,6 @@ function diff_loss_tt(x::Any)
)
end

diff_loss_tn(ps)
diff_loss_tt(ps)
DifferentiationInterface.gradient(diff_loss_tn, ADTypes.AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt, ADTypes.AutoZygote(), ps)
GC.gc()

SUITE["main"]["no_inplace"]["direct"]["train"] =
BenchmarkTools.@benchmarkable diff_loss_tn(ps)
SUITE["main"]["no_inplace"]["direct"]["test"] =
BenchmarkTools.@benchmarkable diff_loss_tt(ps)
SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tn,
ADTypes.AutoZygote(),
ps,
)
SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tt,
ADTypes.AutoZygote(),
ps,
)

icnf2 = ContinuousNormalizingFlows.construct(
ContinuousNormalizingFlows.ICNF,
nn,
nvars,
naugs;
inplace = true,
compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
tspan = (0.0f0, 1.0f0),
steer_rate = 1.0f-1,
λ₁ = 1.0f-2,
λ₂ = 1.0f-2,
λ₃ = 1.0f-2,
rng,
sol_kwargs = (;
save_everystep = false,
alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
),
)

function diff_loss_tn2(x::Any)
return ContinuousNormalizingFlows.loss(
icnf2,
@@ -134,12 +105,47 @@ function diff_loss_tt2(x::Any)
)
end

diff_loss_tn(ps)
diff_loss_tt(ps)
DifferentiationInterface.gradient(diff_loss_tn, ADTypes.AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt, ADTypes.AutoZygote(), ps)

diff_loss_tn2(ps)
diff_loss_tt2(ps)
DifferentiationInterface.gradient(diff_loss_tn2, ADTypes.AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt2, ADTypes.AutoZygote(), ps)

GC.gc()

SUITE = BenchmarkTools.BenchmarkGroup()

SUITE["main"] = BenchmarkTools.BenchmarkGroup(["package", "simple"])

SUITE["main"]["no_inplace"] = BenchmarkTools.BenchmarkGroup(["no_inplace"])
SUITE["main"]["inplace"] = BenchmarkTools.BenchmarkGroup(["inplace"])

SUITE["main"]["no_inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
SUITE["main"]["no_inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])

SUITE["main"]["inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
SUITE["main"]["inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])

SUITE["main"]["no_inplace"]["direct"]["train"] =
BenchmarkTools.@benchmarkable diff_loss_tn(ps)
SUITE["main"]["no_inplace"]["direct"]["test"] =
BenchmarkTools.@benchmarkable diff_loss_tt(ps)
SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tn,
ADTypes.AutoZygote(),
ps,
)
SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tt,
ADTypes.AutoZygote(),
ps,
)
SUITE["main"]["inplace"]["direct"]["train"] =
BenchmarkTools.@benchmarkable diff_loss_tn2(ps)
SUITE["main"]["inplace"]["direct"]["test"] = BenchmarkTools.@benchmarkable diff_loss_tt2(ps)
2 changes: 1 addition & 1 deletion examples/usage.jl
@@ -54,7 +54,7 @@ model = ICNFModel(
icnf;
optimizers = (Adam(),),
adtype = AutoZygote(),
batch_size = 512,
batchsize = 512,
sol_kwargs = (; progress = true, epochs = 300), # pass to the solver
)
mach = machine(model, df)
12 changes: 5 additions & 7 deletions src/ContinuousNormalizingFlows.jl
@@ -25,6 +25,7 @@ import ADTypes,
SciMLSensitivity,
ScientificTypesBase,
Statistics,
WeightInitializers,
Zygote

export construct,
@@ -54,13 +55,10 @@ export construct,
include(joinpath("layers", "cond_layer.jl"))
include(joinpath("layers", "planar_layer.jl"))

include("types.jl")

include("base_icnf.jl")

include("icnf.jl")

include("utils.jl")
include(joinpath("core", "types.jl"))
include(joinpath("core", "base_icnf.jl"))
include(joinpath("core", "icnf.jl"))
include(joinpath("core", "utils.jl"))

include(joinpath("exts", "mlj_ext", "core.jl"))
include(joinpath("exts", "mlj_ext", "core_icnf.jl"))
18 changes: 9 additions & 9 deletions src/base_icnf.jl → src/core/base_icnf.jl
@@ -94,12 +94,12 @@ function n_augment_input(::AbstractICNF)
end

function steer_tspan(
icnf::AbstractICNF{T, <:ComputeMode, INPLACE, COND, AUGMENTED, true},
icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, COND, AUGMENTED, true},
::TrainMode,
) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED}
) where {INPLACE, COND, AUGMENTED}
t₀, t₁ = icnf.tspan
Δt = abs(t₁ - t₀)
r = convert(T, rand(icnf.rng, icnf.steerdist))
r = oftype(t₁, rand(icnf.rng, icnf.steerdist))
t₁_new = muladd(Δt, r, t₁)
return (t₀, t₁_new)
end
@@ -504,12 +504,12 @@ function loss(
end

function make_ode_func(
icnf::AbstractICNF{T, CM, INPLACE},
icnf::AbstractICNF{T},
mode::Mode,
nn::LuxCore.AbstractLuxLayer,
st::NamedTuple,
ϵ::AbstractVecOrMat{T},
) where {T <: AbstractFloat, CM, INPLACE}
) where {T <: AbstractFloat}
function ode_func(u::Any, p::Any, t::Any)
return augmented_f(u, p, t, icnf, mode, nn, st, ϵ)
end
@@ -521,19 +521,19 @@ function make_ode_func(
return ode_func
end

function (icnf::AbstractICNF{T, CM, INPLACE, false})(
function (icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, false})(
xs::AbstractVecOrMat,
ps::Any,
st::NamedTuple,
) where {T, CM, INPLACE}
) where {INPLACE}
return first(inference(icnf, TrainMode(), xs, ps, st)), st
end

function (icnf::AbstractICNF{T, CM, INPLACE, true})(
function (icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, true})(
xs_ys::Tuple,
ps::Any,
st::NamedTuple,
) where {T, CM, INPLACE}
) where {INPLACE}
xs, ys = xs_ys
return first(inference(icnf, TrainMode(), xs, ys, ps, st)), st
end
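Note: the switch from convert(T, …) to oftype(t₁, …) in steer_tspan drops T from the where-clause while still keeping the steering perturbation in the time span's element type. A small illustration — the Uniform steering distribution and the numbers are assumptions, not taken from the package:

import Distributions, Random

rng = Random.Xoshiro(0)
steerdist = Distributions.Uniform(-1.0f-1, 1.0f-1)  # stand-in for icnf.steerdist

t₀, t₁ = (0.0f0, 1.0f0)
Δt = abs(t₁ - t₀)
r = oftype(t₁, rand(rng, steerdist))  # converted to Float32, the type of t₁
t₁_new = muladd(Δt, r, t₁)            # perturbed end time, still Float32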
File renamed without changes.
File renamed without changes.
File renamed without changes.
87 changes: 51 additions & 36 deletions src/exts/dist_ext/core_cond_icnf.jl
@@ -15,51 +15,66 @@ function CondICNFDist(
return CondICNFDist(mach.model.m, mode, ys, ps, st)
end

function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real})
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
first(inference(d.m, d.mode, x, d.ys, d.ps, d.st))
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
@warn "to compute by matrices, data should be a matrix."
first(Distributions._logpdf(d, hcat(x)))
else
error("Not Implemented")
end
function Distributions._logpdf(
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
x::AbstractVector{<:Real},
)
return first(inference(d.m, d.mode, x, d.ys, d.ps, d.st))
end

function Distributions._logpdf(
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
x::AbstractVector{<:Real},
)
@warn "to compute by matrices, data should be a matrix."
return first(Distributions._logpdf(d, hcat(x)))
end

function Distributions._logpdf(
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
A::AbstractMatrix{<:Real},
)
@warn "to compute by vectors, data should be a vector."
return Distributions._logpdf.(d, collect(collect.(eachcol(A))))
end
function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real})
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
@warn "to compute by vectors, data should be a vector."
Distributions._logpdf.(d, collect(collect.(eachcol(A))))
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st))
else
error("Not Implemented")
end

function Distributions._logpdf(
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
A::AbstractMatrix{<:Real},
)
return first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st))
end

function Distributions._rand!(
rng::Random.AbstractRNG,
d::CondICNFDist,
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
x::AbstractVector{<:Real},
)
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
x .= generate(d.m, d.mode, d.ys, d.ps, d.st)
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
@warn "to compute by matrices, data should be a matrix."
x .= Distributions._rand!(rng, d, hcat(x))
else
error("Not Implemented")
end
return x .= generate(d.m, d.mode, d.ys, d.ps, d.st)
end

function Distributions._rand!(
rng::Random.AbstractRNG,
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
x::AbstractVector{<:Real},
)
@warn "to compute by matrices, data should be a matrix."
return x .= Distributions._rand!(rng, d, hcat(x))
end

function Distributions._rand!(
rng::Random.AbstractRNG,
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
A::AbstractMatrix{<:Real},
)
@warn "to compute by vectors, data should be a vector."
return A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...)
end

function Distributions._rand!(
rng::Random.AbstractRNG,
d::CondICNFDist,
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
A::AbstractMatrix{<:Real},
)
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
@warn "to compute by vectors, data should be a vector."
A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...)
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2))
else
error("Not Implemented")
end
return A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2))
end
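Note: the rewrite in this file trades the old runtime `d.m isa …` branches for methods that dispatch on CondICNFDist's type parameter. A stripped-down sketch of that pattern, using placeholder types rather than the package's real ones:

# Placeholder types, only to illustrate the dispatch pattern used above.
abstract type DemoMode end
struct DemoVec <: DemoMode end
struct DemoMat <: DemoMode end

struct DemoFlow{CM <: DemoMode} end
struct DemoDist{F <: DemoFlow}
    m::F
end

# One method per compute mode, instead of a single method full of `isa` branches:
describe(::DemoDist{<:DemoFlow{DemoVec}}, x::AbstractVector) = "vector-mode path"
describe(::DemoDist{<:DemoFlow{DemoMat}}, x::AbstractVector) = "matrix-mode path (wrap x with hcat first)"

describe(DemoDist(DemoFlow{DemoVec}()), Float32[1.0, 2.0])  # returns "vector-mode path"

Each behavior (and its warning) lives in its own method, and the branch is resolved from the wrapped flow's compute mode instead of being checked at runtime.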