Skip to content

Commit d42923c

Browse files
Merge pull request #1141 from AayushSabharwal/as/pre-solve-hook
feat: add `get_updated_symbolic_problem`
2 parents d65d4a6 + 8140ed7 commit d42923c

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

Project.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
3131
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
3232
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3333
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
34+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3435
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
3536

3637
[weakdeps]
@@ -39,8 +40,8 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3940
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
4041
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4142
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
42-
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
4343
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
44+
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
4445
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
4546
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
4647
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
@@ -55,8 +56,8 @@ DiffEqBaseChainRulesCoreExt = "ChainRulesCore"
5556
DiffEqBaseDistributionsExt = "Distributions"
5657
DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"]
5758
DiffEqBaseForwardDiffExt = ["ForwardDiff"]
58-
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
5959
DiffEqBaseGTPSAExt = "GTPSA"
60+
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
6061
DiffEqBaseMPIExt = "MPI"
6162
DiffEqBaseMeasurementsExt = "Measurements"
6263
DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
@@ -82,8 +83,8 @@ FastPower = "1.1"
8283
ForwardDiff = "0.10, 1"
8384
FunctionWrappers = "1.0"
8485
FunctionWrappersWrappers = "0.1"
85-
GeneralizedGenerated = "0.3"
8686
GTPSA = "1.4"
87+
GeneralizedGenerated = "0.3"
8788
LinearAlgebra = "1.9"
8889
Logging = "1.9"
8990
MPI = "0.20"
@@ -105,6 +106,7 @@ SparseArrays = "1.9"
105106
Static = "1"
106107
StaticArraysCore = "1.4"
107108
Statistics = "1"
109+
SymbolicIndexingInterface = "0.3.39"
108110
Tracker = "0.2"
109111
TruncatedStacktraces = "1"
110112
Unitful = "1"
@@ -129,10 +131,9 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
129131
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
130132
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
131133
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
132-
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
133134
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
134135
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
135136
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
136137

137138
[targets]
138-
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]
139+
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]

src/DiffEqBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ Reexport.@reexport using SciMLBase
104104

105105
SciMLBase.isfunctionwrapper(x::FunctionWrapper) = true
106106

107+
import SymbolicIndexingInterface as SII
108+
107109
## Extension Functions
108110

109111
eltypedual(x) = false

src/solve.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,32 @@ function Base.showerror(io::IO, e::LateBindingTstopsNotSupportedError)
515515
println(io, TruncatedStacktraces.VERBOSE_MSG)
516516
end
517517

518+
"""
519+
$(TYPEDSIGNATURES)
520+
521+
Given the index provider `indp` used to construct the problem `prob` being solved, return
522+
an updated `prob` to be used for solving. All implementations should accept arbitrary
523+
keyword arguments.
524+
525+
Should be called before the problem is solved, after performing type-promotion on the
526+
problem.
527+
"""
528+
function get_updated_symbolic_problem(indp, prob; kw...)
529+
return prob
530+
end
531+
532+
"""
533+
$(TYPEDSIGNATURES)
534+
535+
Get the innermost index provider using `SII.symbolic_container`.
536+
"""
537+
function _get_root_indp(indp)
538+
if hasmethod(SII.symbolic_container, Tuple{typeof(indp)}) && (sc = SII.symbolic_container(indp)) !== indp
539+
return _get_root_indp(sc)
540+
end
541+
return indp
542+
end
543+
518544
function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing,
519545
kwargs...)
520546
kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle
@@ -1213,24 +1239,27 @@ function checkkwargs(kwargshandle; kwargs...)
12131239
end
12141240

12151241
function get_concrete_problem(prob::AbstractJumpProblem, isadapt; kwargs...)
1216-
prob
1242+
get_updated_symbolic_problem(_get_root_indp(prob), prob)
12171243
end
12181244

12191245
function get_concrete_problem(prob::SteadyStateProblem, isadapt; kwargs...)
1246+
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
12201247
p = get_concrete_p(prob, kwargs)
12211248
u0 = get_concrete_u0(prob, isadapt, Inf, kwargs)
12221249
u0 = promote_u0(u0, p, nothing)
12231250
remake(prob; u0 = u0, p = p)
12241251
end
12251252

12261253
function get_concrete_problem(prob::NonlinearProblem, isadapt; kwargs...)
1254+
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
12271255
p = get_concrete_p(prob, kwargs)
12281256
u0 = get_concrete_u0(prob, isadapt, nothing, kwargs)
12291257
u0 = promote_u0(u0, p, nothing)
12301258
remake(prob; u0 = u0, p = p)
12311259
end
12321260

12331261
function get_concrete_problem(prob::NonlinearLeastSquaresProblem, isadapt; kwargs...)
1262+
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
12341263
p = get_concrete_p(prob, kwargs)
12351264
u0 = get_concrete_u0(prob, isadapt, nothing, kwargs)
12361265
u0 = promote_u0(u0, p, nothing)
@@ -1252,6 +1281,7 @@ function init(prob::PDEProblem, alg::AbstractDEAlgorithm, args...;
12521281
end
12531282

12541283
function get_concrete_problem(prob, isadapt; kwargs...)
1284+
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
12551285
p = get_concrete_p(prob, kwargs)
12561286
tspan = get_concrete_tspan(prob, isadapt, kwargs, p)
12571287
u0 = get_concrete_u0(prob, isadapt, tspan[1], kwargs)
@@ -1270,6 +1300,7 @@ function get_concrete_problem(prob, isadapt; kwargs...)
12701300
end
12711301

12721302
function get_concrete_problem(prob::DAEProblem, isadapt; kwargs...)
1303+
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
12731304
p = get_concrete_p(prob, kwargs)
12741305
tspan = get_concrete_tspan(prob, isadapt, kwargs, p)
12751306
u0 = get_concrete_u0(prob, isadapt, tspan[1], kwargs)
@@ -1293,6 +1324,7 @@ function get_concrete_problem(prob::DAEProblem, isadapt; kwargs...)
12931324
end
12941325

12951326
function get_concrete_problem(prob::DDEProblem, isadapt; kwargs...)
1327+
prob = get_updated_symbolic_problem(_get_root_indp(prob), prob)
12961328
p = get_concrete_p(prob, kwargs)
12971329
tspan = get_concrete_tspan(prob, isadapt, kwargs, p)
12981330
u0 = get_concrete_u0(prob, isadapt, tspan[1], kwargs)

0 commit comments

Comments
 (0)