Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ function varmap_get(varmap, var, default = nothing)
return default
end

anydict(d::Dict{Any, Any}) = d
anydict(d) = Dict{Any, Any}(d)
anydict() = Dict{Any, Any}()

Expand Down Expand Up @@ -658,8 +659,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p

u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
for (k, v) in u0)
for (k, v) in u0
u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, u0)
end

isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
Expand All @@ -668,17 +670,19 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
# This is sort of an implicit dependency on MTK. The values of `u` won't actually be
# used, since any state symbols in the expression were substituted out earlier.
temp_state = ProblemState(; u = state_values(prob), p = p, t = t0)
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
for (k, v) in u0)
for (k, v) in u0
u0[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
end
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
end

function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))

p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
for (k, v) in p)
for (k, v) in p
p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, p)
end
Comment on lines -680 to +685
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this preferred? This would be a slower construction.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it? I thought it would be faster since it doesn't create a new dictionary

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I didn't see that it was extending an old one.

It's still adding elements one by one. That might amortize?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

julia> function oop(dict)
           return Dict{Any, Any}(k => 2v for (k, v) in dict)
       end
julia> function iip(dict)
           for (k, v) in dict
               dict[k] = 2v
           end
       end
julia> dict = Dict{Any, Any}(rand() => rand() for _ in 1:10000);
julia> oop(dict); iip(dict);
julia> @be oop(dict)
Benchmark: 25 samples with 1 evaluation
 min    2.599 ms (50092 allocs: 1.183 MiB)
 median 2.696 ms (50092 allocs: 1.183 MiB)
 mean   3.925 ms (50092 allocs: 1.183 MiB, 3.60% gc time)
 max    29.143 ms (50092 allocs: 1.183 MiB, 90.08% gc time)

julia> @be iip(dict)
Benchmark: 86 samples with 1 evaluation
 min    1.004 ms (20085 allocs: 313.828 KiB)
 median 1.052 ms (20085 allocs: 313.828 KiB)
 mean   1.140 ms (20085 allocs: 313.828 KiB, 1.77% gc time)
 max    5.962 ms (20085 allocs: 313.828 KiB, 81.95% gc time)

In-place seems faster

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, that's mainly because the generator isn't optimizing there 😅 . Both are slow, but it's fine.


isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
Expand All @@ -687,8 +691,9 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
# this is sort of an implicit dependency on MTK. The values of `p` won't actually be
# used, since any parameter symbols in the expression were substituted out earlier.
temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0)
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
for (k, v) in p)
for (k, v) in p
p[k] = symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
end
return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
end

Expand All @@ -700,20 +705,14 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)),
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
end
if !isu0dep
u0 = remake_buffer(prob, state_values(prob), keys(u0), values(u0))
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0)
end
if !ispdep
p = remake_buffer(prob, parameter_values(prob), keys(p), values(p))
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0)
end

varmap = merge(u0, p)
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
for (k, v) in u0)
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
for (k, v) in p)
for (k, v) in u0
u0[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
end
for (k, v) in p
p[k] = symbolic_type(v) === NotSymbolic() ? v : symbolic_evaluate(v, varmap)
end
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)),
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
end
Expand Down
10 changes: 10 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,13 @@ end
@test newoprob.ps[k] == [2.0, 3.0, 4.0, 5.0]
@test newoprob[V] == [1.5, 2.5]
end

@testset "remake with parameter dependent on observed" begin
@variables x(t) y(t)
@parameters p = x + y
@mtkbuild sys = ODESystem([D(x) ~ x, p ~ x + y], t)
prob = ODEProblem(sys, [x => 1.0, y => 2.0], (0.0, 1.0))
@test prob.ps[p] ≈ 3.0
prob2 = remake(prob; u0 = [y => 3.0], p = Dict())
@test prob2.ps[p] ≈ 4.0
end
Loading