Skip to content

Commit 4d9dd52

Browse files
Merge pull request #2840 from AayushSabharwal/as/deepcopy-ssprob
fix: fix deepcopy for SteadyStateProblem
2 parents 900d5f0 + 91a22a7 commit 4d9dd52

File tree

2 files changed

+13
-26
lines changed

2 files changed

+13
-26
lines changed

src/systems/abstractsystem.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,21 +1223,24 @@ end
12231223
struct ObservedFunctionCache{S}
12241224
sys::S
12251225
dict::Dict{Any, Any}
1226+
steady_state::Bool
12261227
eval_expression::Bool
12271228
eval_module::Module
12281229
end
12291230

1230-
function ObservedFunctionCache(sys; eval_expression = false, eval_module = @__MODULE__)
1231-
return ObservedFunctionCache(sys, Dict(), eval_expression, eval_module)
1231+
function ObservedFunctionCache(
1232+
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__)
1233+
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module)
12321234
end
12331235

12341236
# This is hit because ensemble problems do a deepcopy
12351237
function Base.deepcopy_internal(ofc::ObservedFunctionCache, stackdict::IdDict)
12361238
sys = deepcopy(ofc.sys)
12371239
dict = deepcopy(ofc.dict)
1240+
steady_state = ofc.steady_state
12381241
eval_expression = ofc.eval_expression
12391242
eval_module = ofc.eval_module
1240-
newofc = ObservedFunctionCache(sys, dict, eval_expression, eval_module)
1243+
newofc = ObservedFunctionCache(sys, dict, steady_state, eval_expression, eval_module)
12411244
stackdict[ofc] = newofc
12421245
return newofc
12431246
end
@@ -1248,6 +1251,12 @@ function (ofc::ObservedFunctionCache)(obsvar, args...)
12481251
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
12491252
eval_module = ofc.eval_module)
12501253
end
1254+
if ofc.steady_state
1255+
obs = let fn = obs
1256+
fn1(u, p, t = Inf) = fn(u, p, t)
1257+
fn1
1258+
end
1259+
end
12511260
if args === ()
12521261
return obs
12531262
else

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -399,29 +399,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
399399
ArrayInterface.restructure(u0 .* u0', M)
400400
end
401401

402-
obs = observed(sys)
403-
observedfun = if steady_state
404-
let sys = sys, dict = Dict()
405-
function generated_observed(obsvar, args...)
406-
obs = get!(dict, value(obsvar)) do
407-
SymbolicIndexingInterface.observed(
408-
sys, obsvar; eval_expression, eval_module)
409-
end
410-
if args === ()
411-
return let obs = obs
412-
fn1(u, p, t = Inf) = obs(u, p, t)
413-
fn1
414-
end
415-
elseif length(args) == 2
416-
return obs(args..., Inf)
417-
else
418-
return obs(args...)
419-
end
420-
end
421-
end
422-
else
423-
ObservedFunctionCache(sys; eval_expression, eval_module)
424-
end
402+
observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module)
425403

426404
jac_prototype = if sparse
427405
uElType = u0 === nothing ? Float64 : eltype(u0)

0 commit comments

Comments
 (0)