Describe the bug 🐞
This is the ModelingToolkit (MTK) version of #1160, but in this case the gradient comes back as `nothing`.
Expected behavior
The gradient calculation should work and return a gradient rather than `nothing`.
Minimal Reproducible Example 👇
```julia
using OrdinaryDiffEqTsit5
using SciMLSensitivity
using Zygote
using ModelingToolkit
using ModelingToolkit: D_nounits as D, t_nounits as t

function mae2(sol, data)
    l = zero(eltype(data))
    for i in axes(data, 2)
        for j in axes(data, 1)
            l += abs2(sol.u[i][j] - data[j, i])
        end
    end
    l / length(data)
end

function lotka()
    @variables x(t) = 3.1 y(t) = 1.5
    @parameters α = 1.3 β = 0.9 γ = 0.8 δ = 1.8
    eqs = [
        D(x) ~ α * x - β * x * y,
        D(y) ~ -δ * y + γ * x * y
    ]
    return complete(System(eqs, t, name=:lotka))
end

sys = lotka()

function ensemble_setup(x, sys)
    function prob_func(prob, i, repeat)
        remake(prob, u0=rand(2), initializealg=BrownFullBasicInit())
    end
    # function f(du, u, p, t)
    #     du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    #     du[2] = -3 * u[2] + u[1] * u[2]
    # end
    # prob = ODEProblem(f, [0.5, 0.5], (0.0, 1.0), x)
    prob = ODEProblem(sys, [], (0., 1))
    prob, prob_func
end

function ensemble_loss(x, data, sys)
    prob, prob_func = ensemble_setup(x, sys)
    function output_func(sol, i)
        (mae2(sol, data), false)
    end
    ensembleprob = EnsembleProblem(prob; prob_func, output_func, safetycopy=false)
    sim = solve(ensembleprob, Tsit5(), EnsembleSerial();
        trajectories=3, saveat=[0., 0.4, 0.9],
        save_end=true)
    sum(sim)
end

_data = [1.1 2 4
         0 5. 6]
s = ensemble_loss(rand(2), _data, sys)
Zygote.gradient(x -> ensemble_loss(x, _data, sys), rand(2))
```
Error & Stacktrace
```
(nothing,)
```
Environment
- Output of `using Pkg; Pkg.status()`
```
[961ee093] ModelingToolkit v10.18.0
[1ed8b502] SciMLSensitivity v7.89.0
[e88e6eb3] Zygote v0.7.10
```
- Output of `using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)`
- Output of `versioninfo()`
Additional context
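Zygote returns `nothing` for an argument that the differentiated function never reads. In the MRE, `x` is passed to `ensemble_setup` but is not used to build `prob` (the parameters take their `@parameters` defaults and `prob_func` only sets `u0=rand(2)`), so `(nothing,)` may reflect that the loss as written does not depend on `x`. The minimal sketch below uses only Zygote, is independent of the MTK setup, and relies on two hypothetical toy losses to show this behavior:

```julia
using Zygote

# Hypothetical loss that ignores its argument, like `x` in the MRE above.
loss_ignoring_x(x) = sum(abs2, [1.0, 2.0])

# Hypothetical loss that actually reads its argument.
loss_using_x(x) = sum(abs2, x)

g_ignored = Zygote.gradient(loss_ignoring_x, [0.5, 0.5])  # (nothing,)
g_used = Zygote.gradient(loss_using_x, [0.5, 0.5])        # ([1.0, 1.0],)
```

If the MTK version is meant to mirror #1160, where `x` is passed as the parameter vector `p`, the values in `x` would need to reach `prob` (for example through `remake`) before a non-`nothing` gradient can be expected.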