We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e5be661 commit 174608eCopy full SHA for 174608e
test/gpu/diffeqflux_standard_gpu.jl
@@ -16,7 +16,7 @@ function trueODEfunc(du, u, p, t)
16
end
17
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
18
# Make the data into a GPU-based array if the user has a GPU
19
-ode_data = gdev(solve(prob_trueode, Tsit5(), saveat = tsteps))
+ode_data = gdev(Array(solve(prob_trueode, Tsit5(), saveat = tsteps)))
20
21
dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
22
u0 = Float32[2.0; 0.0] |> gdev
@@ -26,12 +26,12 @@ ps, st = Lux.setup(Random.default_rng(), dudt2)
26
ps = ComponentArray(ps) |> gdev
27
28
function predict_neuralode(p)
29
- gdev(first(prob_neuralode(u0, p, st)))
+ first(prob_neuralode(u0, p, st))
30
31
function loss_neuralode(p)
32
pred = predict_neuralode(p)
33
loss = sum(abs2, ode_data .- pred)
34
return loss
35
36
37
-Zygote.gradient(loss_neuralode, ps)
+Zygote.gradient(loss_neuralode, ps)
0 commit comments