Skip to content

Commit eedfc28

Browse files
Merge branch 'master' into speed
2 parents f440f76 + 74ef64f commit eedfc28

File tree

9 files changed

+117
-12
lines changed

9 files changed

+117
-12
lines changed

docs/src/examples/remake.md

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,8 @@ function loss(x, p)
5151
odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
5252
ps = parameter_values(odeprob) # obtain the parameter object from the problem
5353
ps = replace(Tunable(), ps, x) # create a copy with the values passed to the loss function
54-
T = eltype(x)
55-
# we also have to convert the `u0` vector
56-
u0 = T.(state_values(odeprob))
5754
# remake the problem, passing in our new parameter object
58-
newprob = remake(odeprob; u0 = u0, p = ps)
55+
newprob = remake(odeprob; p = ps)
5956
timesteps = p[2]
6057
sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
6158
truth = p[3]

src/systems/abstractsystem.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,11 @@ function namespace_defaults(sys)
859859
for (k, v) in pairs(defs))
860860
end
861861

862+
function namespace_guesses(sys)
863+
guess = guesses(sys)
864+
Dict(unknowns(sys, k) => namespace_expr(v, sys) for (k, v) in guess)
865+
end
866+
862867
function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys))
863868
eqs = equations(sys)
864869
isempty(eqs) && return Equation[]
@@ -968,7 +973,13 @@ function full_parameters(sys::AbstractSystem)
968973
end
969974

970975
function guesses(sys::AbstractSystem)
971-
get_guesses(sys)
976+
guess = get_guesses(sys)
977+
systems = get_systems(sys)
978+
isempty(systems) && return guess
979+
for subsys in systems
980+
guess = merge(guess, namespace_guesses(subsys))
981+
end
982+
return guess
972983
end
973984

974985
# required in `src/connectors.jl:437`

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ function flatten(sys::ODESystem, noeqs = false)
354354
get_iv(sys),
355355
unknowns(sys),
356356
parameters(sys),
357+
guesses = guesses(sys),
357358
observed = observed(sys),
358359
continuous_events = continuous_events(sys),
359360
discrete_events = discrete_events(sys),

src/systems/index_cache.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ struct ParameterIndex{P, I}
1717
end
1818

1919
const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
20-
const UnknownIndexMap = Dict{Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}}}
20+
const UnknownIndexMap = Dict{
21+
Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
2122

2223
struct IndexCache
2324
unknown_idx::UnknownIndexMap
@@ -40,17 +41,32 @@ function IndexCache(sys::AbstractSystem)
4041
for sym in unks
4142
usym = unwrap(sym)
4243
sym_idx = if Symbolics.isarraysymbolic(sym)
43-
idx:(idx + length(sym) - 1)
44+
reshape(idx:(idx + length(sym) - 1), size(sym))
4445
else
4546
idx
4647
end
4748
unk_idxs[usym] = sym_idx
4849

49-
if hasname(sym)
50+
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
5051
unk_idxs[getname(usym)] = sym_idx
5152
end
5253
idx += length(sym)
5354
end
55+
for sym in unks
56+
usym = unwrap(sym)
57+
istree(sym) && operation(sym) === getindex || continue
58+
arrsym = arguments(sym)[1]
59+
all(haskey(unk_idxs, arrsym[i]) for i in eachindex(arrsym)) || continue
60+
61+
idxs = [unk_idxs[arrsym[i]] for i in eachindex(arrsym)]
62+
if idxs == idxs[begin]:idxs[end]
63+
idxs = reshape(idxs[begin]:idxs[end], size(idxs))
64+
end
65+
unk_idxs[arrsym] = idxs
66+
if hasname(arrsym)
67+
unk_idxs[getname(arrsym)] = idxs
68+
end
69+
end
5470
end
5571

5672
disc_buffers = Dict{Any, Set{BasicSymbolic}}()
@@ -124,7 +140,7 @@ function IndexCache(sys::AbstractSystem)
124140
for (j, p) in enumerate(buf)
125141
idxs[p] = (i, j)
126142
idxs[default_toterm(p)] = (i, j)
127-
if hasname(p)
143+
if hasname(p) && (!istree(p) || operation(p) !== getindex)
128144
idxs[getname(p)] = (i, j)
129145
idxs[getname(default_toterm(p))] = (i, j)
130146
end

src/systems/parameter_buffer.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,15 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
402402
return newbuf
403403
end
404404

405+
function DiffEqBase.anyeltypedual(
406+
p::MTKParameters, ::Type{Val{counter}} = Val{0}) where {counter}
407+
DiffEqBase.anyeltypedual(p.tunable)
408+
end
409+
function DiffEqBase.anyeltypedual(p::Type{<:MTKParameters{T}},
410+
::Type{Val{counter}} = Val{0}) where {counter} where {T}
411+
DiffEqBase.__anyeltypedual(T)
412+
end
413+
405414
_subarrays(v::AbstractVector) = isempty(v) ? () : (v,)
406415
_subarrays(v::ArrayPartition) = v.x
407416
_subarrays(v::Tuple) = v

src/variables.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ end
456456
## Guess ======================================================================
457457
struct VariableGuess end
458458
Symbolics.option_to_metadata_type(::Val{:guess}) = VariableGuess
459-
getguess(x::Num) = getguess(Symbolics.unwrap(x))
459+
getguess(x::Union{Num, Symbolics.Arr}) = getguess(Symbolics.unwrap(x))
460460

461461
"""
462462
getguess(x)
@@ -469,8 +469,6 @@ Create variables with a guess like this
469469
```
470470
"""
471471
function getguess(x)
472-
p = Symbolics.getparent(x, nothing)
473-
p === nothing || (x = p)
474472
Symbolics.getmetadata(x, VariableGuess, nothing)
475473
end
476474

test/mtkparameters.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using ModelingToolkit
22
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
33
using SymbolicIndexingInterface
44
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
5+
using OrdinaryDiffEq
56
using ForwardDiff
67
using JET
78

@@ -206,3 +207,24 @@ end
206207
portion, ps, ones(length(buffer)))
207208
end
208209
end
210+
211+
# Issue#2642
212+
@parameters α β γ δ
213+
@variables x(t) y(t)
214+
eqs = [D(x) ~- β * y) * x
215+
D(y) ~* x - γ) * y]
216+
@mtkbuild odesys = ODESystem(eqs, t)
217+
odeprob = ODEProblem(
218+
odesys, [x => 1.0, y => 1.0], (0.0, 10.0), [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0])
219+
tunables, _... = canonicalize(Tunable(), odeprob.p)
220+
@test tunables isa AbstractVector{Float64}
221+
222+
function loss(x)
223+
ps = odeprob.p
224+
newps = SciMLStructures.replace(Tunable(), ps, x)
225+
newprob = remake(odeprob, p = newps)
226+
sol = solve(newprob, Tsit5())
227+
return sum(sol)
228+
end
229+
230+
@test_nowarn ForwardDiff.gradient(loss, collect(tunables))

test/odesystem.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,3 +1120,42 @@ tearing_state = TearingState(expand_connections(sys))
11201120
ts_vars = tearing_state.fullvars
11211121
orig_vars = unknowns(sys)
11221122
@test isempty(setdiff(ts_vars, orig_vars))
1123+
1124+
# Guesses in hierarchical systems
1125+
@variables x(t) y(t)
1126+
@named sys = ODESystem(Equation[], t, [x], []; guesses = [x => 1.0])
1127+
@named outer = ODESystem(
1128+
[D(y) ~ sys.x + t, 0 ~ t + y - sys.x * y], t, [y], []; systems = [sys])
1129+
@test ModelingToolkit.guesses(outer)[sys.x] == 1.0
1130+
outer = structural_simplify(outer)
1131+
@test ModelingToolkit.get_guesses(outer)[sys.x] == 1.0
1132+
prob = ODEProblem(outer, [outer.y => 2.0], (0.0, 10.0))
1133+
int = init(prob, Rodas4())
1134+
@test int[outer.sys.x] == 1.0
1135+
1136+
# Ensure indexes of array symbolics are cached appropriately
1137+
@variables x(t)[1:2]
1138+
@named sys = ODESystem(Equation[], t, [x], [])
1139+
sys1 = complete(sys)
1140+
@named sys = ODESystem(Equation[], t, [x...], [])
1141+
sys2 = complete(sys)
1142+
for sys in [sys1, sys2]
1143+
for (sym, idx) in [(x, 1:2), (x[1], 1), (x[2], 2)]
1144+
@test is_variable(sys, sym)
1145+
@test variable_index(sys, sym) == idx
1146+
end
1147+
end
1148+
1149+
@variables x(t)[1:2, 1:2]
1150+
@named sys = ODESystem(Equation[], t, [x], [])
1151+
sys1 = complete(sys)
1152+
@named sys = ODESystem(Equation[], t, [x...], [])
1153+
sys2 = complete(sys)
1154+
for sys in [sys1, sys2]
1155+
@test is_variable(sys, x)
1156+
@test variable_index(sys, x) == [1 3; 2 4]
1157+
for i in eachindex(x)
1158+
@test is_variable(sys, x[i])
1159+
@test variable_index(sys, x[i]) == variable_index(sys, x)[i]
1160+
end
1161+
end

test/test_variable_metadata.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@ using ModelingToolkit
1616
@test hasguess(y) === true
1717
@test ModelingToolkit.dump_variable_metadata(y).guess == 0
1818

19+
# Issue#2653
20+
@variables y[1:3] [guess = ones(3)]
21+
@test getguess(y) == ones(3)
22+
@test hasguess(y) === true
23+
@test ModelingToolkit.dump_variable_metadata(y).guess == ones(3)
24+
25+
for i in 1:3
26+
@test getguess(y[i]) == 1.0
27+
@test hasguess(y[i]) === true
28+
@test ModelingToolkit.dump_variable_metadata(y[i]).guess == 1.0
29+
end
30+
1931
@variables y
2032
@test hasguess(y) === false
2133
@test !haskey(ModelingToolkit.dump_variable_metadata(y), :guess)

0 commit comments

Comments
 (0)