Skip to content

Commit 174608e

Browse files
Fix GPU test
1 parent e5be661 commit 174608e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

test/gpu/diffeqflux_standard_gpu.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function trueODEfunc(du, u, p, t)
1616
end
1717
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
1818
# Make the data into a GPU-based array if the user has a GPU
19-
ode_data = gdev(solve(prob_trueode, Tsit5(), saveat = tsteps))
19+
ode_data = gdev(Array(solve(prob_trueode, Tsit5(), saveat = tsteps)))
2020

2121
dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
2222
u0 = Float32[2.0; 0.0] |> gdev
@@ -26,12 +26,12 @@ ps, st = Lux.setup(Random.default_rng(), dudt2)
2626
ps = ComponentArray(ps) |> gdev
2727

2828
function predict_neuralode(p)
29-
gdev(first(prob_neuralode(u0, p, st)))
29+
first(prob_neuralode(u0, p, st))
3030
end
3131
function loss_neuralode(p)
3232
pred = predict_neuralode(p)
3333
loss = sum(abs2, ode_data .- pred)
3434
return loss
3535
end
3636

37-
Zygote.gradient(loss_neuralode, ps)
37+
Zygote.gradient(loss_neuralode, ps)

0 commit comments

Comments
 (0)