@@ -617,30 +617,40 @@ struct ReconstructInitializeprob{GP, GU}
617617 ugetter:: GU
618618end
619619
620+ """
621+ $(TYPEDEF)
622+
623+ A wrapper over an observed function which allows calling it on a problem-like object.
624+ `TD` determines whether the getter function is `(u, p, t)` (if `true`) or `(u, p)` (if
625+ `false`).
626+ """
627+ struct ObservedWrapper{TD, F}
628+ f:: F
629+ end
630+
631+ ObservedWrapper {TD} (f:: F ) where {TD, F} = ObservedWrapper {TD, F} (f)
632+
633+ function (ow:: ObservedWrapper{true} )(prob)
634+ ow. f (state_values (prob), parameter_values (prob), current_time (prob))
635+ end
636+
637+ function (ow:: ObservedWrapper{false} )(prob)
638+ ow. f (state_values (prob), parameter_values (prob))
639+ end
640+
620641"""
621642 $(TYPEDSIGNATURES)
622643
623644Given an index provider `indp` and a vector of symbols `syms` return a type-stable getter
624- function by splitting `syms` into contiguous buffers where the getter of each buffer
625- is type-stable and constructing a function that calls and concatenates the results.
626- """
627- function concrete_getu (indp, syms:: AbstractVector )
628- # a list of contiguous buffer
629- split_syms = [Any[syms[1 ]]]
630- # the type of the getter of the last buffer
631- current = typeof (getu (indp, syms[1 ]))
632- for sym in syms[2 : end ]
633- getter = getu (indp, sym)
634- if typeof (getter) != current
635- # if types don't match, build a new buffer
636- push! (split_syms, [])
637- current = typeof (getter)
638- end
639- push! (split_syms[end ], sym)
640- end
641- split_syms = Tuple (split_syms)
642- # the getter is now type-stable, and we can vcat it to get the full buffer
643- return Base. Fix1 (reduce, vcat) ∘ getu (indp, split_syms)
645+ function.
646+
647+ Note that the getter ONLY works for problem-like objects, since it generates an observed
648+ function. It does NOT work for solutions.
649+ """
650+ Base. @nospecializeinfer function concrete_getu (indp, syms:: AbstractVector )
651+ @nospecialize
652+ obsfn = SymbolicIndexingInterface. observed (indp, syms)
653+ return ObservedWrapper {is_time_dependent(indp)} (obsfn)
644654end
645655
646656"""
0 commit comments