Skip to content

Commit 2255638

Browse files
committed
fix: remove deprecated API
1 parent c8e8c91 commit 2255638

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

test/gpu/mixed_gpu_cpu_adjoint.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using SciMLSensitivity, OrdinaryDiffEq
22
using Lux, LuxCUDA, Test, Zygote, Random, LinearAlgebra, ComponentArrays
33

4+
const gdev = gpu_device()
5+
46
CUDA.allowscalar(false)
57

68
H = CuArray(rand(Float32, 2, 2))
@@ -42,10 +44,10 @@ grad = Zygote.gradient(cost, p)[1]
4244
rng = MersenneTwister(1234)
4345
m = 32
4446
n = 16
45-
Z = randn(rng, Float32, (n, m)) |> gpu
47+
Z = randn(rng, Float32, (n, m)) |> gdev
4648
𝒯 = 2.0f0
4749
Δτ = 1.0f-1
48-
ca_init = [zeros(1); ones(m)] |> gpu
50+
ca_init = [zeros(1); ones(m)] |> gdev
4951

5052
function f(ca, Z, t)
5153
a = ca[2:end]
@@ -54,7 +56,7 @@ function f(ca, Z, t)
5456
Ka_unit = Z' * w_unit
5557
z_unit = dot(abs.(Ka_unit), a_unit)
5658
aKa_over_z = a .* Ka_unit / z_unit
57-
[sum(aKa_over_z) / m; -abs.(aKa_over_z)] |> gpu
59+
[sum(aKa_over_z) / m; -abs.(aKa_over_z)] |> gdev
5860
end
5961

6062
function c(Z)

0 commit comments

Comments
 (0)