Skip to content

Commit 7eb5354

Browse files
Merge pull request #3005 from AayushSabharwal/as/autodiff-defaults
fix: improve resolution of dependent parameter defaults
2 parents a25a254 + 2f10bf5 commit 7eb5354

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

src/systems/parameter_buffer.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,7 @@ function MTKParameters(
105105
end
106106

107107
isempty(missing_params) || throw(MissingParametersError(collect(missing_params)))
108-
109-
p = Dict(unwrap(k) => fixpoint_sub(v, bigdefs) for (k, v) in p)
108+
p = Dict(unwrap(k) => (bigdefs[unwrap(k)] = fixpoint_sub(v, bigdefs)) for (k, v) in p)
110109
for (sym, _) in p
111110
if iscall(sym) && operation(sym) === getindex &&
112111
first(arguments(sym)) in all_ps

test/extensions/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
3+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
34
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
45
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
56
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"

test/extensions/ad.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using SymbolicIndexingInterface
55
using SciMLStructures
66
using OrdinaryDiffEq
77
using SciMLSensitivity
8+
using ForwardDiff
89

910
@variables x(t)[1:3] y(t)
1011
@parameters p[1:3, 1:3] q
@@ -27,3 +28,26 @@ gs = gradient(new_p) do new_p
2728
new_sol = solve(new_prob, Tsit5())
2829
sum(new_sol)
2930
end
31+
32+
@testset "Issue#2997" begin
33+
pars = @parameters y0 mh Tγ0 Th0 h ργ0
34+
vars = @variables x(t)
35+
@named sys = ODESystem([D(x) ~ y0],
36+
t,
37+
vars,
38+
pars;
39+
defaults = [
40+
y0 => mh * 3.1 / (2.3 * Th0),
41+
mh => 123.4,
42+
Th0 => (4 / 11)^(1 / 3) * Tγ0,
43+
Tγ0 => (15 / π^2 * ργ0 * (2 * h)^2 / 7)^(1 / 4) / 5
44+
])
45+
sys = structural_simplify(sys)
46+
47+
function x_at_0(θ)
48+
prob = ODEProblem(sys, [sys.x => 1.0], (0.0, 1.0), [sys.ργ0 => θ[1], sys.h => θ[2]])
49+
return prob.u0[1]
50+
end
51+
52+
@test ForwardDiff.gradient(x_at_0, [0.3, 0.7]) == zeros(2)
53+
end

0 commit comments

Comments
 (0)