Gradient issue with Ensemble solve (MTK version) #1268

@SebastianM-C

Describe the bug 🐞

This is the ModelingToolkit (MTK) version of #1160, but in this case the gradient call returns nothing instead of a numeric gradient.

Expected behavior

The gradient calculation should work and return a numeric gradient rather than nothing.

Minimal Reproducible Example 👇

using OrdinaryDiffEqTsit5
using SciMLSensitivity
using Zygote
using ModelingToolkit
using ModelingToolkit: D_nounits as D, t_nounits as t

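# Mean of squared differences between the saved states and `data`
# (each column of `data` corresponds to one saved time point).
function mae2(sol, data)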
    l = zero(eltype(data))
    for i in axes(data, 2)
        for j in axes(data, 1)
            l += abs2(sol.u[i][j] - data[j, i])
        end
    end

    l / length(data)
end

function lotka()
    @variables x(t) = 3.1 y(t) = 1.5
    @parameters α = 1.3 β = 0.9 γ = 0.8 δ = 1.8
    eqs = [
        D(x) ~ α * x - β * x * y,
        D(y) ~ -δ * y + γ * x * y
    ]
    return complete(System(eqs, t, name=:lotka))
end

sys = lotka()

function ensemble_setup(x, sys)
    function prob_func(prob, i, repeat)
        remake(prob, u0=rand(2), initializealg=BrownFullBasicInit())
    end

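    # Plain ODEProblem variant, kept for comparison (here `x` is the parameter vector):
    # function f(du, u, p, t)
    #     du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    #     du[2] = -3 * u[2] + u[1] * u[2]
    # end

    # prob = ODEProblem(f, [0.5, 0.5], (0.0, 1.0), x)
    # MTK variant: defaults come from `sys`; note that `x` is not passed in here.
    prob = ODEProblem(sys, [], (0., 1))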

    prob, prob_func
end

function ensemble_loss(x, data, sys)
    prob, prob_func = ensemble_setup(x, sys)

    function output_func(sol, i)
        (mae2(sol, data), false)
    end

    ensembleprob = EnsembleProblem(prob; prob_func, output_func, safetycopy=false)

    sim = solve(ensembleprob, Tsit5(), EnsembleSerial();
        trajectories=3, saveat=[0., 0.4, 0.9],
        save_end=true)

    sum(sim)
end
_data = [1.1 2 4
    0 5. 6]
s = ensemble_loss(rand(2), _data, sys)

Zygote.gradient(x -> ensemble_loss(x, _data, sys), rand(2))

Error & Stacktrace ⚠️

(nothing,)
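
One thing that may be worth ruling out when reading this result: Zygote returns `nothing` for an argument that the output does not depend on. In the MRE above, `x` is passed to `ensemble_setup`, but the only place it is used is the commented-out `ODEProblem(f, ...)` line. The sketch below only illustrates that general Zygote behavior; `loss_ignoring_x` and `loss_using_x` are hypothetical stand-ins, not part of the MRE, and this is not a confirmed diagnosis of the issue.

```julia
using Zygote

# Hypothetical stand-in losses, not taken from the MRE above.
loss_ignoring_x(x) = 1.0          # output does not depend on x
loss_using_x(x) = sum(abs2, x)    # output depends on x

g_ignored = Zygote.gradient(loss_ignoring_x, [1.0, 2.0])
g_used = Zygote.gradient(loss_using_x, [1.0, 2.0])

println(g_ignored)  # (nothing,)
println(g_used)     # ([2.0, 4.0],)
```

If the same `(nothing,)` still appears once `x` actually reaches the problem (for example through the parameters passed to `ODEProblem` or `remake`), that would point more clearly at the ensemble/MTK adjoint path.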

Environment:

  • Output of using Pkg; Pkg.status()
  [961ee093] ModelingToolkit v10.18.0
  [1ed8b502] SciMLSensitivity v7.89.0
  [e88e6eb3] Zygote v0.7.10
