Skip to content

Commit a381821

Browse files
authored
migrate to MLDataDevices (#464)
* migrate to `MLDataDevices` * compat
1 parent 13956b8 commit a381821

File tree

12 files changed

+54
-111
lines changed

12 files changed

+54
-111
lines changed

Project.toml

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ version = "0.26.1"
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
10-
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
1110
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1211
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1312
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
@@ -17,6 +16,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1716
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1817
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1918
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
19+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
2020
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
2121
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
2222
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -32,18 +32,10 @@ ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
3232
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3333
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3434

35-
[weakdeps]
36-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
37-
38-
[extensions]
39-
ContinuousNormalizingFlowsCUDAExt = "CUDA"
40-
4135
[compat]
4236
ADTypes = "1"
43-
CUDA = "5"
4437
ChainRulesCore = "1"
4538
ComponentArrays = "0.15"
46-
ComputationalResources = "0.3"
4739
DataFrames = "1"
4840
Dates = "1"
4941
DifferentiationInterface = "0.6"
@@ -53,6 +45,7 @@ FillArrays = "1"
5345
LinearAlgebra = "1"
5446
Lux = "1"
5547
LuxCore = "1"
48+
MLDataDevices = "1"
5649
MLJBase = "1"
5750
MLJModelInterface = "1"
5851
MLUtils = "0.4"

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ n_in = nvars + naugs # with augmentation
4343
n = 1024
4444

4545
# Model
46-
using ContinuousNormalizingFlows, Lux, ADTypes #, Zygote, CUDA, ComputationalResources
46+
using ContinuousNormalizingFlows, Lux, ADTypes #, Zygote, CUDA, MLDataDevices
4747
nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh))
4848
icnf = construct(
4949
RNODE,
@@ -52,7 +52,7 @@ icnf = construct(
5252
naugs; # number of augmented dimensions
5353
# compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote
5454
# inplace = true, # use the inplace version of functions
55-
# resource = CUDALibs(), # process data by GPU
55+
# device = gpu_device(), # process data by GPU
5656
tspan = (0.0f0, 13.0f0), # have bigger time span
5757
steer_rate = 1.0f-1, # add random noise to end of the time span
5858
λ₁ = 1.0f-2, # regulate flow

ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl

Lines changed: 0 additions & 17 deletions
This file was deleted.

src/ContinuousNormalizingFlows.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import ADTypes,
44
Base.Iterators,
55
ChainRulesCore,
66
ComponentArrays,
7-
ComputationalResources,
87
DataFrames,
98
Dates,
109
DifferentiationInterface,
@@ -14,6 +13,7 @@ import ADTypes,
1413
LinearAlgebra,
1514
Lux,
1615
LuxCore,
16+
MLDataDevices,
1717
MLJBase,
1818
MLJModelInterface,
1919
MLUtils,

src/base_icnf.jl

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function construct(
77
compute_mode::ComputeMode = LuxVecJacMatrixMode(ADTypes.AutoZygote()),
88
inplace::Bool = false,
99
cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar},
10-
resource::ComputationalResources.AbstractResource = ComputationalResources.CPU1(),
10+
device::MLDataDevices.AbstractDevice = MLDataDevices.cpu_device(),
1111
basedist::Distributions.Distribution = Distributions.MvNormal(
1212
FillArrays.Zeros{data_type}(nvars + naugmented),
1313
FillArrays.Eye{data_type}(nvars + naugmented),
@@ -19,7 +19,7 @@ function construct(
1919
FillArrays.Eye{data_type}(nvars + naugmented),
2020
),
2121
sol_kwargs::NamedTuple = (;),
22-
rng::Random.AbstractRNG = rng_AT(resource),
22+
rng::Random.AbstractRNG = MLDataDevices.default_device_rng(device),
2323
λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE}
2424
convert(data_type, 1.0e-2)
2525
else
@@ -46,7 +46,7 @@ function construct(
4646
!iszero(λ₃),
4747
typeof(nn),
4848
typeof(nvars),
49-
typeof(resource),
49+
typeof(device),
5050
typeof(basedist),
5151
typeof(tspan),
5252
typeof(steerdist),
@@ -58,7 +58,7 @@ function construct(
5858
nvars,
5959
naugmented,
6060
compute_mode,
61-
resource,
61+
device,
6262
basedist,
6363
tspan,
6464
steerdist,
@@ -104,16 +104,8 @@ end
104104
icnf.tspan
105105
end
106106

107-
@inline function rng_AT(::ComputationalResources.AbstractResource)
108-
Random.default_rng()
109-
end
110-
111-
@inline function base_AT(
112-
::ComputationalResources.AbstractResource,
113-
::AbstractICNF{T},
114-
dims...,
115-
) where {T <: AbstractFloat}
116-
Array{T}(undef, dims...)
107+
@inline function base_AT(icnf::AbstractICNF{T}, dims...) where {T <: AbstractFloat}
108+
icnf.device(Array{T}(undef, dims...))
117109
end
118110

119111
ChainRulesCore.@non_differentiable base_AT(::Any...)
@@ -213,7 +205,7 @@ function inference_prob(
213205
n_aug_input = n_augment_input(icnf)
214206
zrs = similar(xs, n_aug_input + n_aug + 1)
215207
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
216-
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
208+
ϵ = base_AT(icnf, icnf.nvars + n_aug_input)
217209
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
218210
nn = icnf.nn
219211
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
@@ -236,7 +228,7 @@ function inference_prob(
236228
n_aug_input = n_augment_input(icnf)
237229
zrs = similar(xs, n_aug_input + n_aug + 1)
238230
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
239-
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
231+
ϵ = base_AT(icnf, icnf.nvars + n_aug_input)
240232
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
241233
nn = CondLayer(icnf.nn, ys)
242234
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
@@ -258,7 +250,7 @@ function inference_prob(
258250
n_aug_input = n_augment_input(icnf)
259251
zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2))
260252
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
261-
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
253+
ϵ = base_AT(icnf, icnf.nvars + n_aug_input, size(xs, 2))
262254
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
263255
nn = icnf.nn
264256
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
@@ -281,7 +273,7 @@ function inference_prob(
281273
n_aug_input = n_augment_input(icnf)
282274
zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2))
283275
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
284-
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2))
276+
ϵ = base_AT(icnf, icnf.nvars + n_aug_input, size(xs, 2))
285277
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
286278
nn = CondLayer(icnf.nn, ys)
287279
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
@@ -300,11 +292,11 @@ function generate_prob(
300292
) where {T <: AbstractFloat, INPLACE}
301293
n_aug = n_augment(icnf, mode)
302294
n_aug_input = n_augment_input(icnf)
303-
new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
295+
new_xs = base_AT(icnf, icnf.nvars + n_aug_input)
304296
Random.rand!(icnf.rng, icnf.basedist, new_xs)
305297
zrs = similar(new_xs, n_aug + 1)
306298
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
307-
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
299+
ϵ = base_AT(icnf, icnf.nvars + n_aug_input)
308300
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
309301
nn = icnf.nn
310302
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
@@ -324,11 +316,11 @@ function generate_prob(
324316
) where {T <: AbstractFloat, INPLACE}
325317
n_aug = n_augment(icnf, mode)
326318
n_aug_input = n_augment_input(icnf)
327-
new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
319+
new_xs = base_AT(icnf, icnf.nvars + n_aug_input)
328320
Random.rand!(icnf.rng, icnf.basedist, new_xs)
329321
zrs = similar(new_xs, n_aug + 1)
330322
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
331-
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input)
323+
ϵ = base_AT(icnf, icnf.nvars + n_aug_input)
332324
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
333325
nn = CondLayer(icnf.nn, ys)
334326
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
@@ -348,11 +340,11 @@ function generate_prob(
348340
) where {T <: AbstractFloat, INPLACE}
349341
n_aug = n_augment(icnf, mode)
350342
n_aug_input = n_augment_input(icnf)
351-
new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
343+
new_xs = base_AT(icnf, icnf.nvars + n_aug_input, n)
352344
Random.rand!(icnf.rng, icnf.basedist, new_xs)
353345
zrs = similar(new_xs, n_aug + 1, n)
354346
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
355-
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
347+
ϵ = base_AT(icnf, icnf.nvars + n_aug_input, n)
356348
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
357349
nn = icnf.nn
358350
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(
@@ -373,11 +365,11 @@ function generate_prob(
373365
) where {T <: AbstractFloat, INPLACE}
374366
n_aug = n_augment(icnf, mode)
375367
n_aug_input = n_augment_input(icnf)
376-
new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
368+
new_xs = base_AT(icnf, icnf.nvars + n_aug_input, n)
377369
Random.rand!(icnf.rng, icnf.basedist, new_xs)
378370
zrs = similar(new_xs, n_aug + 1, n)
379371
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T))
380-
ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n)
372+
ϵ = base_AT(icnf, icnf.nvars + n_aug_input, n)
381373
Random.rand!(icnf.rng, icnf.epsdist, ϵ)
382374
nn = CondLayer(icnf.nn, ys)
383375
SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}(

src/exts/mlj_ext/core_cond_icnf.jl

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,12 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
2828
X, Y = XY
2929
x = collect(transpose(MLJModelInterface.matrix(X)))
3030
y = collect(transpose(MLJModelInterface.matrix(Y)))
31-
tdev = if model.m.resource isa ComputationalResources.CUDALibs
32-
Lux.gpu_device()
33-
else
34-
Lux.cpu_device()
35-
end
3631
ps, st = LuxCore.setup(model.m.rng, model.m)
3732
ps = ComponentArrays.ComponentArray(ps)
38-
x = tdev(x)
39-
y = tdev(y)
40-
ps = tdev(ps)
41-
st = tdev(st)
33+
x = model.m.device(x)
34+
y = model.m.device(y)
35+
ps = model.m.device(ps)
36+
st = model.m.device(st)
4237
data = if model.m.compute_mode isa VectorMode
4338
MLUtils.DataLoader((x, y); batchsize = -1, shuffle = true, partial = true)
4439
elseif model.m.compute_mode isa MatrixMode
@@ -55,7 +50,7 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
5550
else
5651
error("Not Implemented")
5752
end
58-
data = tdev(data)
53+
data = model.m.device(data)
5954
optfunc = SciMLBase.OptimizationFunction(
6055
make_opt_loss(model.m, TrainMode(), st, model.loss),
6156
model.adtype,
@@ -94,13 +89,8 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew)
9489
Xnew, Ynew = XYnew
9590
xnew = collect(transpose(MLJModelInterface.matrix(Xnew)))
9691
ynew = collect(transpose(MLJModelInterface.matrix(Ynew)))
97-
tdev = if model.m.resource isa ComputationalResources.CUDALibs
98-
Lux.gpu_device()
99-
else
100-
Lux.cpu_device()
101-
end
102-
xnew = tdev(xnew)
103-
ynew = tdev(ynew)
92+
xnew = model.m.device(xnew)
93+
ynew = model.m.device(ynew)
10494
(ps, st) = fitresult
10595

10696
tst = @timed if model.m.compute_mode isa VectorMode

src/exts/mlj_ext/core_icnf.jl

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,11 @@ end
2626

2727
function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
2828
x = collect(transpose(MLJModelInterface.matrix(X)))
29-
tdev = if model.m.resource isa ComputationalResources.CUDALibs
30-
Lux.gpu_device()
31-
else
32-
Lux.cpu_device()
33-
end
3429
ps, st = LuxCore.setup(model.m.rng, model.m)
3530
ps = ComponentArrays.ComponentArray(ps)
36-
x = tdev(x)
37-
ps = tdev(ps)
38-
st = tdev(st)
31+
x = model.m.device(x)
32+
ps = model.m.device(ps)
33+
st = model.m.device(st)
3934
data = if model.m.compute_mode isa VectorMode
4035
MLUtils.DataLoader((x,); batchsize = -1, shuffle = true, partial = true)
4136
elseif model.m.compute_mode isa MatrixMode
@@ -52,7 +47,7 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
5247
else
5348
error("Not Implemented")
5449
end
55-
data = tdev(data)
50+
data = model.m.device(data)
5651
optfunc = SciMLBase.OptimizationFunction(
5752
make_opt_loss(model.m, TrainMode(), st, model.loss),
5853
model.adtype,
@@ -90,12 +85,7 @@ end
9085

9186
function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew)
9287
xnew = collect(transpose(MLJModelInterface.matrix(Xnew)))
93-
tdev = if model.m.resource isa ComputationalResources.CUDALibs
94-
Lux.gpu_device()
95-
else
96-
Lux.cpu_device()
97-
end
98-
xnew = tdev(xnew)
88+
xnew = model.m.device(xnew)
9989
(ps, st) = fitresult
10090

10191
tst = @timed if model.m.compute_mode isa VectorMode

src/icnf.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ struct ICNF{
7878
NORM_Z_AUG,
7979
NN <: LuxCore.AbstractLuxLayer,
8080
NVARS <: Int,
81-
RESOURCE <: ComputationalResources.AbstractResource,
81+
DEVICE <: MLDataDevices.AbstractDevice,
8282
BASEDIST <: Distributions.Distribution,
8383
TSPAN <: NTuple{2, T},
8484
STEERDIST <: Distributions.Distribution,
@@ -91,7 +91,7 @@ struct ICNF{
9191
naugmented::NVARS
9292

9393
compute_mode::CM
94-
resource::RESOURCE
94+
device::DEVICE
9595
basedist::BASEDIST
9696
tspan::TSPAN
9797
steerdist::STEERDIST

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
44
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
5-
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
65
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
76
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
87
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
98
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
109
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1110
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1211
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
12+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
1313
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
1414
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1515
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -21,13 +21,13 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2121
ADTypes = "1"
2222
Aqua = "0.8"
2323
ComponentArrays = "0.15"
24-
ComputationalResources = "0.3"
2524
DataFrames = "1"
2625
DifferentiationInterface = "0.6"
2726
Distances = "0.10"
2827
Distributions = "0.25"
2928
JET = "0.9"
3029
Lux = "1"
30+
MLDataDevices = "1"
3131
MLJBase = "1"
3232
SciMLBase = "2"
3333
StableRNGs = "1"

0 commit comments

Comments
 (0)