@@ -646,30 +646,40 @@ struct ReconstructInitializeprob{GP, GU}
646646 ugetter:: GU
647647end
648648
649+ """
650+ $(TYPEDEF)
651+
652+ A wrapper over an observed function which allows calling it on a problem-like object.
653+ `TD` determines whether the getter function is `(u, p, t)` (if `true`) or `(u, p)` (if
654+ `false`).
655+ """
656+ struct ObservedWrapper{TD, F}
657+ f:: F
658+ end
659+
660+ ObservedWrapper {TD} (f:: F ) where {TD, F} = ObservedWrapper {TD, F} (f)
661+
662+ function (ow:: ObservedWrapper{true} )(prob)
663+ ow. f (state_values (prob), parameter_values (prob), current_time (prob))
664+ end
665+
666+ function (ow:: ObservedWrapper{false} )(prob)
667+ ow. f (state_values (prob), parameter_values (prob))
668+ end
669+
649670"""
650671 $(TYPEDSIGNATURES)
651672
652673Given an index provider `indp` and a vector of symbols `syms` return a type-stable getter
653- function by splitting `syms` into contiguous buffers where the getter of each buffer
654- is type-stable and constructing a function that calls and concatenates the results.
655- """
656- function concrete_getu (indp, syms:: AbstractVector )
657- # a list of contiguous buffer
658- split_syms = [Any[syms[1 ]]]
659- # the type of the getter of the last buffer
660- current = typeof (getu (indp, syms[1 ]))
661- for sym in syms[2 : end ]
662- getter = getu (indp, sym)
663- if typeof (getter) != current
664- # if types don't match, build a new buffer
665- push! (split_syms, [])
666- current = typeof (getter)
667- end
668- push! (split_syms[end ], sym)
669- end
670- split_syms = Tuple (split_syms)
671- # the getter is now type-stable, and we can vcat it to get the full buffer
672- return Base. Fix1 (reduce, vcat) ∘ getu (indp, split_syms)
674+ function.
675+
676+ Note that the getter ONLY works for problem-like objects, since it generates an observed
677+ function. It does NOT work for solutions.
678+ """
679+ Base. @nospecializeinfer function concrete_getu (indp, syms:: AbstractVector )
680+ @nospecialize
681+ obsfn = SymbolicIndexingInterface. observed (indp, syms)
682+ return ObservedWrapper {is_time_dependent(indp)} (obsfn)
673683end
674684
675685"""
0 commit comments