@@ -7,7 +7,7 @@ function construct(
77 compute_mode:: ComputeMode = LuxVecJacMatrixMode(ADTypes. AutoZygote()),
88 inplace:: Bool = false ,
99 cond:: Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar} ,
10- resource :: ComputationalResources.AbstractResource = ComputationalResources . CPU1 (),
10+ device :: MLDataDevices.AbstractDevice = MLDataDevices . cpu_device (),
1111 basedist:: Distributions.Distribution = Distributions. MvNormal(
1212 FillArrays. Zeros{data_type}(nvars + naugmented),
1313 FillArrays. Eye{data_type}(nvars + naugmented),
@@ -19,7 +19,7 @@ function construct(
1919 FillArrays. Eye{data_type}(nvars + naugmented),
2020 ),
2121 sol_kwargs:: NamedTuple = (;),
22- rng:: Random.AbstractRNG = rng_AT(resource ),
22+ rng:: Random.AbstractRNG = MLDataDevices . default_device_rng(device ),
2323 λ₁:: AbstractFloat = if aicnf <: Union{RNODE, CondRNODE}
2424 convert(data_type, 1.0e-2 )
2525 else
@@ -46,7 +46,7 @@ function construct(
4646 ! iszero(λ₃),
4747 typeof(nn),
4848 typeof(nvars),
49- typeof(resource ),
49+ typeof(device ),
5050 typeof(basedist),
5151 typeof(tspan),
5252 typeof(steerdist),
@@ -58,7 +58,7 @@ function construct(
5858 nvars,
5959 naugmented,
6060 compute_mode,
61- resource ,
61+ device ,
6262 basedist,
6363 tspan,
6464 steerdist,
104104 icnf. tspan
105105end
106106
107- @inline function rng_AT(:: ComputationalResources.AbstractResource )
108- Random. default_rng()
109- end
110-
111- @inline function base_AT(
112- :: ComputationalResources.AbstractResource ,
113- :: AbstractICNF{T} ,
114- dims... ,
115- ) where {T <: AbstractFloat }
116- Array{T}(undef, dims... )
107+ @inline function base_AT(icnf:: AbstractICNF{T} , dims... ) where {T <: AbstractFloat }
108+ icnf. device(Array{T}(undef, dims... ))
117109end
118110
119111ChainRulesCore. @non_differentiable base_AT(:: Any... )
@@ -213,7 +205,7 @@ function inference_prob(
213205 n_aug_input = n_augment_input(icnf)
214206 zrs = similar(xs, n_aug_input + n_aug + 1 )
215207 ChainRulesCore. @ignore_derivatives fill!(zrs, zero(T))
216- ϵ = base_AT(icnf. resource, icnf , icnf. nvars + n_aug_input)
208+ ϵ = base_AT(icnf, icnf. nvars + n_aug_input)
217209 Random. rand!(icnf. rng, icnf. epsdist, ϵ)
218210 nn = icnf. nn
219211 SciMLBase. ODEProblem{INPLACE, SciMLBase. FullSpecialize}(
@@ -236,7 +228,7 @@ function inference_prob(
236228 n_aug_input = n_augment_input(icnf)
237229 zrs = similar(xs, n_aug_input + n_aug + 1 )
238230 ChainRulesCore. @ignore_derivatives fill!(zrs, zero(T))
239- ϵ = base_AT(icnf. resource, icnf , icnf. nvars + n_aug_input)
231+ ϵ = base_AT(icnf, icnf. nvars + n_aug_input)
240232 Random. rand!(icnf. rng, icnf. epsdist, ϵ)
241233 nn = CondLayer(icnf. nn, ys)
242234 SciMLBase. ODEProblem{INPLACE, SciMLBase. FullSpecialize}(
@@ -258,7 +250,7 @@ function inference_prob(
258250 n_aug_input = n_augment_input(icnf)
259251 zrs = similar(xs, n_aug_input + n_aug + 1 , size(xs, 2 ))
260252 ChainRulesCore. @ignore_derivatives fill!(zrs, zero(T))
261- ϵ = base_AT(icnf. resource, icnf , icnf. nvars + n_aug_input, size(xs, 2 ))
253+ ϵ = base_AT(icnf, icnf. nvars + n_aug_input, size(xs, 2 ))
262254 Random. rand!(icnf. rng, icnf. epsdist, ϵ)
263255 nn = icnf. nn
264256 SciMLBase. ODEProblem{INPLACE, SciMLBase. FullSpecialize}(
@@ -281,7 +273,7 @@ function inference_prob(
281273 n_aug_input = n_augment_input(icnf)
282274 zrs = similar(xs, n_aug_input + n_aug + 1 , size(xs, 2 ))
283275 ChainRulesCore. @ignore_derivatives fill!(zrs, zero(T))
284- ϵ = base_AT(icnf. resource, icnf , icnf. nvars + n_aug_input, size(xs, 2 ))
276+ ϵ = base_AT(icnf, icnf. nvars + n_aug_input, size(xs, 2 ))
285277 Random. rand!(icnf. rng, icnf. epsdist, ϵ)
286278 nn = CondLayer(icnf. nn, ys)
287279 SciMLBase. ODEProblem{INPLACE, SciMLBase. FullSpecialize}(
@@ -300,11 +292,11 @@ function generate_prob(
300292) where {T <: AbstractFloat , INPLACE}
301293 n_aug = n_augment(icnf, mode)
302294 n_aug_input = n_augment_input(icnf)
303- new_xs = base_AT(icnf. resource, icnf , icnf. nvars + n_aug_input)
295+ new_xs = base_AT(icnf, icnf. nvars + n_aug_input)
304296 Random. rand!(icnf. rng, icnf. basedist, new_xs)
305297 zrs = similar(new_xs, n_aug + 1 )
306298 ChainRulesCore. @ignore_derivatives fill!(zrs, zero(T))
307- ϵ = base_AT(icnf. resource, icnf , icnf. nvars + n_aug_input)
299+ ϵ = base_AT(icnf, icnf. nvars + n_aug_input)
308300 Random. rand!(icnf. rng, icnf. epsdist, ϵ)
309301 nn = icnf. nn
310302 SciMLBase. ODEProblem{INPLACE, SciMLBase. FullSpecialize}(
@@ -324,11 +316,11 @@ function generate_prob(
324316) where {T <: AbstractFloat , INPLACE}
325317 n_aug = n_augment(icnf, mode)
326318 n_aug_input = n_augment_input(icnf)
327- new_xs = base_AT(icnf. resource, icnf , icnf. nvars + n_aug_input)
319+ new_xs = base_AT(icnf, icnf. nvars + n_aug_input)
328320 Random. rand!(icnf. rng, icnf. basedist, new_xs)
329321 zrs = similar(new_xs, n_aug + 1 )
330322 ChainRulesCore. @ignore_derivatives fill!(zrs, zero(T))
331- ϵ = base_AT(icnf. resource, icnf , icnf. nvars + n_aug_input)
323+ ϵ = base_AT(icnf, icnf. nvars + n_aug_input)
332324 Random. rand!(icnf. rng, icnf. epsdist, ϵ)
333325 nn = CondLayer(icnf. nn, ys)
334326 SciMLBase. ODEProblem{INPLACE, SciMLBase. FullSpecialize}(
@@ -348,11 +340,11 @@ function generate_prob(
348340) where {T <: AbstractFloat , INPLACE}
349341 n_aug = n_augment(icnf, mode)
350342 n_aug_input = n_augment_input(icnf)
351- new_xs = base_AT(icnf. resource, icnf , icnf. nvars + n_aug_input, n)
343+ new_xs = base_AT(icnf, icnf. nvars + n_aug_input, n)
352344 Random. rand!(icnf. rng, icnf. basedist, new_xs)
353345 zrs = similar(new_xs, n_aug + 1 , n)
354346 ChainRulesCore. @ignore_derivatives fill!(zrs, zero(T))
355- ϵ = base_AT(icnf. resource, icnf , icnf. nvars + n_aug_input, n)
347+ ϵ = base_AT(icnf, icnf. nvars + n_aug_input, n)
356348 Random. rand!(icnf. rng, icnf. epsdist, ϵ)
357349 nn = icnf. nn
358350 SciMLBase. ODEProblem{INPLACE, SciMLBase. FullSpecialize}(
@@ -373,11 +365,11 @@ function generate_prob(
373365) where {T <: AbstractFloat , INPLACE}
374366 n_aug = n_augment(icnf, mode)
375367 n_aug_input = n_augment_input(icnf)
376- new_xs = base_AT(icnf. resource, icnf , icnf. nvars + n_aug_input, n)
368+ new_xs = base_AT(icnf, icnf. nvars + n_aug_input, n)
377369 Random. rand!(icnf. rng, icnf. basedist, new_xs)
378370 zrs = similar(new_xs, n_aug + 1 , n)
379371 ChainRulesCore. @ignore_derivatives fill!(zrs, zero(T))
380- ϵ = base_AT(icnf. resource, icnf , icnf. nvars + n_aug_input, n)
372+ ϵ = base_AT(icnf, icnf. nvars + n_aug_input, n)
381373 Random. rand!(icnf. rng, icnf. epsdist, ϵ)
382374 nn = CondLayer(icnf. nn, ys)
383375 SciMLBase. ODEProblem{INPLACE, SciMLBase. FullSpecialize}(
0 commit comments