@@ -257,35 +257,48 @@ function find_solve_sequence(partitions, vars)
257
257
end
258
258
259
259
function build_observed_function (
260
- sys, syms ;
260
+ sys, ts ;
261
261
expression= false ,
262
262
output_type= Array,
263
263
checkbounds= true
264
264
)
265
265
266
- if (isscalar = ! (syms isa Vector ))
267
- syms = [syms ]
266
+ if (isscalar = ! (ts isa AbstractVector ))
267
+ ts = [ts ]
268
268
end
269
- syms = value .(syms)
270
- syms_set = Set (syms)
269
+ ts = Symbolics. scalarize .(value .(ts))
270
+
271
+ vars = Set ()
272
+ foreach (Base. Fix1 (vars!, vars), ts)
273
+ ivs = independent_variables (sys)
274
+ dep_vars = collect (setdiff (vars, ivs))
275
+
271
276
s = structure (sys)
272
277
@unpack partitions, fullvars, graph = s
273
278
diffvars = map (i-> fullvars[i], diffvars_range (s))
274
279
algvars = map (i-> fullvars[i], algvars_range (s))
275
280
276
- required_algvars = Set (intersect (algvars, syms_set ))
281
+ required_algvars = Set (intersect (algvars, vars ))
277
282
obs = observed (sys)
278
283
observed_idx = Dict (map (x-> x. lhs, obs) .=> 1 : length (obs))
279
284
# FIXME : this is a rather rough estimate of dependencies.
280
285
maxidx = 0
281
- for (i, s) in enumerate (syms)
286
+ sts = Set (states (sys))
287
+ for (i, s) in enumerate (dep_vars)
282
288
idx = get (observed_idx, s, nothing )
283
- idx === nothing && continue
289
+ if idx === nothing
290
+ if ! (s in sts)
291
+ throw (ArgumentError (" $s is either an observed nor a state variable." ))
292
+ end
293
+ continue
294
+ end
284
295
idx > maxidx && (maxidx = idx)
285
296
end
297
+ vs = Set ()
286
298
for idx in 1 : maxidx
287
- vs = vars ( obs[idx]. rhs)
299
+ vars! (vs, obs[idx]. rhs)
288
300
union! (required_algvars, intersect (algvars, vs))
301
+ empty! (vs)
289
302
end
290
303
291
304
varidxs = findall (x-> x in required_algvars, fullvars)
@@ -301,12 +314,11 @@ function build_observed_function(
301
314
solves = []
302
315
end
303
316
304
- output = map (syms) do sym
305
- if sym in required_algvars
306
- sym
307
- else
308
- obs[observed_idx[sym]]. rhs
309
- end
317
+ subs = []
318
+ for sym in vars
319
+ eqidx = get (observed_idx, sym, nothing )
320
+ eqidx === nothing && continue
321
+ push! (subs, sym ← obs[eqidx]. rhs)
310
322
end
311
323
pre = get_postprocess_fbody (sys)
312
324
@@ -321,8 +333,9 @@ function build_observed_function(
321
333
[
322
334
collect (Iterators. flatten (solves))
323
335
map (eq -> eq. lhs← eq. rhs, obs[1 : maxidx])
336
+ subs
324
337
],
325
- isscalar ? output [1 ] : MakeArray (output , output_type)
338
+ isscalar ? ts [1 ] : MakeArray (ts , output_type)
326
339
))
327
340
) |> Code. toexpr
328
341
0 commit comments