Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ version = "0.26.1"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Expand All @@ -17,6 +16,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Expand All @@ -32,18 +32,10 @@ ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
ContinuousNormalizingFlowsCUDAExt = "CUDA"

[compat]
ADTypes = "1"
CUDA = "5"
ChainRulesCore = "1"
ComponentArrays = "0.15"
ComputationalResources = "0.3"
DataFrames = "1"
Dates = "1"
DifferentiationInterface = "0.6"
Expand All @@ -53,6 +45,7 @@ FillArrays = "1"
LinearAlgebra = "1"
Lux = "1"
LuxCore = "1"
MLDataDevices = "1"
MLJBase = "1"
MLJModelInterface = "1"
MLUtils = "0.4"
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ n_in = nvars + naugs # with augmentation
n = 1024

# Model
using ContinuousNormalizingFlows, Lux, ADTypes #, Zygote, CUDA, ComputationalResources
using ContinuousNormalizingFlows, Lux, ADTypes #, Zygote, CUDA, MLDataDevices
nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh))
icnf = construct(
RNODE,
Expand All @@ -52,7 +52,7 @@ icnf = construct(
naugs; # number of augmented dimensions
# compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote
# inplace = true, # use the inplace version of functions
# resource = CUDALibs(), # process data by GPU
# device = gpu_device(), # process data by GPU
tspan = (0.0f0, 13.0f0), # have bigger time span
steer_rate = 1.0f-1, # add random noise to end of the time span
λ₁ = 1.0f-2, # regulate flow
Expand Down

This file was deleted.

2 changes: 1 addition & 1 deletion src/ContinuousNormalizingFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import ADTypes,
Base.Iterators,
ChainRulesCore,
ComponentArrays,
ComputationalResources,
DataFrames,
Dates,
DifferentiationInterface,
Expand All @@ -14,6 +13,7 @@ import ADTypes,
LinearAlgebra,
Lux,
LuxCore,
MLDataDevices,
MLJBase,
MLJModelInterface,
MLUtils,
Expand Down
44 changes: 18 additions & 26 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function construct(
compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()),
inplace::Bool = false,
cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar},
resource::ComputationalResources.AbstractResource = ComputationalResources.CPU1(),
device::MLDataDevices.AbstractDevice = MLDataDevices.cpu_device(),
basedist::Distributions.Distribution = Distributions.MvNormal(
FillArrays.Zeros{data_type}(nvars + naugmented),
FillArrays.Eye{data_type}(nvars + naugmented),
Expand All @@ -19,7 +19,7 @@ function construct(
FillArrays.Eye{data_type}(nvars + naugmented),
),
sol_kwargs::NamedTuple = (;),
rng::Random.AbstractRNG = rng_AT(resource),
rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device),
λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE}
convert(data_type, 1.0e-2)
else
Expand All @@ -46,7 +46,7 @@ function construct(
!iszero(λ₃),
typeof(nn),
typeof(nvars),
typeof(resource),
typeof(device),
typeof(basedist),
typeof(tspan),
typeof(steerdist),
Expand All @@ -58,7 +58,7 @@ function construct(
nvars,
naugmented,
compute_mode,
resource,
device,
basedist,
tspan,
steerdist,
Expand Down Expand Up @@ -104,16 +104,8 @@ end
icnf.tspan
end

@inline function rng_AT(::ComputationalResources.AbstractResource)
Random.default_rng()
end

@inline function base_AT(
::ComputationalResources.AbstractResource,
::AbstractICNF{T},
dims...,
) where {T <: AbstractFloat}
Array{T}(undef, dims...)
@inline function base_AT(icnf::AbstractICNF{T}, dims...) where {T <: AbstractFloat}
icnf.device(Array{T}(undef, dims...))
end

ChainRulesCore.@non_differentiable base_AT(::Any...)
Expand Down Expand Up @@ -213,7 +205,7 @@ function inference_prob(
n_aug_input = n_augment_input(icnf)
zrs = similar(xs, n_aug_input + n_aug + 1)
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
ϵ = base_AT(icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = icnf.nn
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
Expand All @@ -236,7 +228,7 @@ function inference_prob(
n_aug_input = n_augment_input(icnf)
zrs = similar(xs, n_aug_input + n_aug + 1)
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
ϵ = base_AT(icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = CondLayer(icnf.nn, ys)
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
Expand All @@ -258,7 +250,7 @@ function inference_prob(
n_aug_input = n_augment_input(icnf)
zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2))
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
ϵ = base_AT(icnf, icnf.nvars + n_aug_input, size(xs, 2))
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = icnf.nn
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
Expand All @@ -281,7 +273,7 @@ function inference_prob(
n_aug_input = n_augment_input(icnf)
zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2))
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
ϵ = base_AT(icnf, icnf.nvars + n_aug_input, size(xs, 2))
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = CondLayer(icnf.nn, ys)
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
Expand All @@ -300,11 +292,11 @@ function generate_prob(
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
new_xs = base_AT(icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.basedist, new_xs)
zrs = similar(new_xs, n_aug + 1)
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
ϵ = base_AT(icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = icnf.nn
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
Expand All @@ -324,11 +316,11 @@ function generate_prob(
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
new_xs = base_AT(icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.basedist, new_xs)
zrs = similar(new_xs, n_aug + 1)
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
ϵ = base_AT(icnf, icnf.nvars + n_aug_input)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = CondLayer(icnf.nn, ys)
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
Expand All @@ -348,11 +340,11 @@ function generate_prob(
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
new_xs = base_AT(icnf, icnf.nvars + n_aug_input, n)
Random.rand!(icnf.rng, icnf.basedist, new_xs)
zrs = similar(new_xs, n_aug + 1, n)
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
ϵ = base_AT(icnf, icnf.nvars + n_aug_input, n)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = icnf.nn
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
Expand All @@ -373,11 +365,11 @@ function generate_prob(
) where {T <: AbstractFloat, INPLACE}
n_aug = n_augment(icnf, mode)
n_aug_input = n_augment_input(icnf)
new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
new_xs = base_AT(icnf, icnf.nvars + n_aug_input, n)
Random.rand!(icnf.rng, icnf.basedist, new_xs)
zrs = similar(new_xs, n_aug + 1, n)
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
ϵ = base_AT(icnf, icnf.nvars + n_aug_input, n)
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
nn = CondLayer(icnf.nn, ys)
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
Expand Down
24 changes: 7 additions & 17 deletions src/exts/mlj_ext/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,12 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
X, Y = XY
x = collect(transpose(MLJModelInterface.matrix(X)))
y = collect(transpose(MLJModelInterface.matrix(Y)))
tdev = if model.m.resource isa ComputationalResources.CUDALibs
Lux.gpu_device()
else
Lux.cpu_device()
end
ps, st = LuxCore.setup(model.m.rng, model.m)
ps = ComponentArrays.ComponentArray(ps)
x = tdev(x)
y = tdev(y)
ps = tdev(ps)
st = tdev(st)
x = model.m.device(x)
y = model.m.device(y)
ps = model.m.device(ps)
st = model.m.device(st)
data = if model.m.compute_mode isa VectorMode
MLUtils.DataLoader((x, y); batchsize = -1, shuffle = true, partial = true)
elseif model.m.compute_mode isa MatrixMode
Expand All @@ -55,7 +50,7 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
else
error("Not Implemented")
end
data = tdev(data)
data = model.m.device(data)
optfunc = SciMLBase.OptimizationFunction(
make_opt_loss(model.m, TrainMode(), st, model.loss),
model.adtype,
Expand Down Expand Up @@ -94,13 +89,8 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew)
Xnew, Ynew = XYnew
xnew = collect(transpose(MLJModelInterface.matrix(Xnew)))
ynew = collect(transpose(MLJModelInterface.matrix(Ynew)))
tdev = if model.m.resource isa ComputationalResources.CUDALibs
Lux.gpu_device()
else
Lux.cpu_device()
end
xnew = tdev(xnew)
ynew = tdev(ynew)
xnew = model.m.device(xnew)
ynew = model.m.device(ynew)
(ps, st) = fitresult

tst = @timed if model.m.compute_mode isa VectorMode
Expand Down
20 changes: 5 additions & 15 deletions src/exts/mlj_ext/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,11 @@ end

function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
x = collect(transpose(MLJModelInterface.matrix(X)))
tdev = if model.m.resource isa ComputationalResources.CUDALibs
Lux.gpu_device()
else
Lux.cpu_device()
end
ps, st = LuxCore.setup(model.m.rng, model.m)
ps = ComponentArrays.ComponentArray(ps)
x = tdev(x)
ps = tdev(ps)
st = tdev(st)
x = model.m.device(x)
ps = model.m.device(ps)
st = model.m.device(st)
data = if model.m.compute_mode isa VectorMode
MLUtils.DataLoader((x,); batchsize = -1, shuffle = true, partial = true)
elseif model.m.compute_mode isa MatrixMode
Expand All @@ -52,7 +47,7 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
else
error("Not Implemented")
end
data = tdev(data)
data = model.m.device(data)
optfunc = SciMLBase.OptimizationFunction(
make_opt_loss(model.m, TrainMode(), st, model.loss),
model.adtype,
Expand Down Expand Up @@ -90,12 +85,7 @@ end

function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew)
xnew = collect(transpose(MLJModelInterface.matrix(Xnew)))
tdev = if model.m.resource isa ComputationalResources.CUDALibs
Lux.gpu_device()
else
Lux.cpu_device()
end
xnew = tdev(xnew)
xnew = model.m.device(xnew)
(ps, st) = fitresult

tst = @timed if model.m.compute_mode isa VectorMode
Expand Down
4 changes: 2 additions & 2 deletions src/icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ struct ICNF{
NORM_Z_AUG,
NN <: LuxCore.AbstractLuxLayer,
NVARS <: Int,
RESOURCE <: ComputationalResources.AbstractResource,
DEVICE <: MLDataDevices.AbstractDevice,
BASEDIST <: Distributions.Distribution,
TSPAN <: NTuple{2, T},
STEERDIST <: Distributions.Distribution,
Expand All @@ -91,7 +91,7 @@ struct ICNF{
naugmented::NVARS

compute_mode::CM
resource::RESOURCE
device::DEVICE
basedist::BASEDIST
tspan::TSPAN
steerdist::STEERDIST
Expand Down
4 changes: 2 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -21,13 +21,13 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ADTypes = "1"
Aqua = "0.8"
ComponentArrays = "0.15"
ComputationalResources = "0.3"
DataFrames = "1"
DifferentiationInterface = "0.6"
Distances = "0.10"
Distributions = "0.25"
JET = "0.9"
Lux = "1"
MLDataDevices = "1"
MLJBase = "1"
SciMLBase = "2"
StableRNGs = "1"
Expand Down
Loading
Loading