Skip to content

Commit 4f28525

Browse files
feat: add get_updated_symbolic_problem
1 parent a4135ab commit 4f28525

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

src/DiffEqBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ Reexport.@reexport using SciMLBase
104104

105105
SciMLBase.isfunctionwrapper(x::FunctionWrapper) = true
106106

107+
import SymbolicIndexingInterface as SII
108+
107109
## Extension Functions
108110

109111
eltypedual(x) = false

src/solve.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,12 +515,39 @@ function Base.showerror(io::IO, e::LateBindingTstopsNotSupportedError)
515515
println(io, TruncatedStacktraces.VERBOSE_MSG)
516516
end
517517

518+
"""
519+
$(TYPEDSIGNATURES)
520+
521+
Given the index provider `indp` used to construct the problem `prob` being solved, return
522+
an updated `prob` to be used for solving. All implementations should accept arbitrary
523+
keyword arguments.
524+
525+
Should be called before the problem is solved, after performing type-promotion on the
526+
problem.
527+
"""
528+
function get_updated_symbolic_problem(indp, prob; kw...)
529+
return prob
530+
end
531+
532+
"""
533+
$(TYPEDSIGNATURES)
534+
535+
Get the innermost index provider using `SII.symbolic_container`.
536+
"""
537+
function _get_root_indp(indp)
538+
if hasmethod(SII.symbolic_container, Tuple{typeof(indp)}) && (sc = SII.symbolic_container(indp)) !== indp
539+
return _get_root_indp(sc)
540+
end
541+
return indp
542+
end
543+
518544
function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing,
519545
kwargs...)
520546
kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle
521547
kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ?
522548
_prob.kwargs[:kwargshandle] : kwargshandle
523549

550+
_prob = get_updated_symbolic_problem(_get_root_indp(_prob), _prob)
524551
if has_kwargs(_prob)
525552
if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback)
526553
kwargs_temp = NamedTuple{

0 commit comments

Comments
 (0)