Skip to content

Commit 7b0e4a7

Browse files
Merge pull request #3760 from AayushSabharwal/as/linear-scc
feat: use `LinearProblem` for linear SCCs in `SCCNonlinearProblem`
2 parents 53a4867 + ca8be7b commit 7b0e4a7

File tree

7 files changed

+142
-66
lines changed

7 files changed

+142
-66
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ REPL = "1"
147147
RecursiveArrayTools = "3.26"
148148
Reexport = "0.2, 1"
149149
RuntimeGeneratedFunctions = "0.5.9"
150-
SCCNonlinearSolve = "1.0.0"
150+
SCCNonlinearSolve = "1.4.0"
151151
SciMLBase = "2.108.0"
152152
SciMLPublic = "1.0.0"
153153
SciMLStructures = "1.7"
@@ -156,7 +156,7 @@ Setfield = "0.7, 0.8, 1"
156156
SimpleNonlinearSolve = "0.1.0, 1, 2"
157157
SparseArrays = "1"
158158
SpecialFunctions = "1, 2"
159-
StaticArrays = "0.10, 0.11, 0.12, 1.0"
159+
StaticArrays = "1.9.14"
160160
StochasticDelayDiffEq = "1.10"
161161
StochasticDiffEq = "6.72.1"
162162
SymbolicIndexingInterface = "0.3.39"

src/problems/linearproblem.jl

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,42 @@
1+
struct LinearFunction{iip, I} <: SciMLBase.AbstractSciMLFunction{iip}
2+
interface::I
3+
A::AbstractMatrix
4+
b::AbstractVector
5+
end
6+
7+
function LinearFunction{iip}(
8+
sys::System; expression = Val{false}, check_compatibility = true,
9+
sparse = false, eval_expression = false, eval_module = @__MODULE__,
10+
checkbounds = false, cse = true, kwargs...) where {iip}
11+
check_complete(sys, LinearProblem)
12+
check_compatibility && check_compatible_system(LinearProblem, sys)
13+
14+
A, b = calculate_A_b(sys; sparse)
15+
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
16+
eval_module, checkbounds, cse, kwargs...)
17+
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
18+
eval_module, checkbounds, cse, kwargs...)
19+
observedfun = ObservedFunctionCache(
20+
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
21+
cse)
22+
23+
if expression == Val{true}
24+
symbolic_interface = quote
25+
update_A = $update_A
26+
update_b = $update_b
27+
sys = $sys
28+
observedfun = $observedfun
29+
$(SciMLBase.SymbolicLinearInterface)(
30+
update_A, update_b, sys, observedfun, nothing)
31+
end
32+
else
33+
symbolic_interface = SciMLBase.SymbolicLinearInterface(
34+
update_A, update_b, sys, observedfun, nothing)
35+
end
36+
37+
return LinearFunction{iip, typeof(symbolic_interface)}(symbolic_interface, A, b)
38+
end
39+
140
function SciMLBase.LinearProblem(sys::System, op; kwargs...)
241
SciMLBase.LinearProblem{true}(sys, op; kwargs...)
342
end
@@ -9,14 +48,14 @@ end
948
function SciMLBase.LinearProblem{iip}(
1049
sys::System, op; check_length = true, expression = Val{false},
1150
check_compatibility = true, sparse = false, eval_expression = false,
12-
eval_module = @__MODULE__, checkbounds = false, cse = true,
13-
u0_constructor = identity, u0_eltype = nothing, kwargs...) where {iip}
51+
eval_module = @__MODULE__, u0_constructor = identity, u0_eltype = nothing,
52+
kwargs...) where {iip}
1453
check_complete(sys, LinearProblem)
1554
check_compatibility && check_compatible_system(LinearProblem, sys)
1655

17-
_, u0,
56+
f, u0,
1857
p = process_SciMLProblem(
19-
EmptySciMLFunction{iip}, sys, op; check_length, expression,
58+
LinearFunction{iip}, sys, op; check_length, expression,
2059
build_initializeprob = false, symbolic_u0 = true, u0_constructor, u0_eltype,
2160
kwargs...)
2261

@@ -33,45 +72,38 @@ function SciMLBase.LinearProblem{iip}(
3372
u0_eltype = something(u0_eltype, floatT)
3473

3574
u0_constructor = get_p_constructor(u0_constructor, u0Type, u0_eltype)
75+
symbolic_interface = f.interface
76+
A,
77+
b = get_A_b_from_LinearFunction(
78+
sys, f, p; eval_expression, eval_module, expression, u0_constructor, sparse)
3679

37-
A, b = calculate_A_b(sys; sparse)
38-
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
39-
eval_module, checkbounds, cse, kwargs...)
40-
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
41-
eval_module, checkbounds, cse, kwargs...)
42-
observedfun = ObservedFunctionCache(
43-
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
44-
cse)
80+
kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
81+
args = (; A, b, p)
4582

83+
return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
84+
end
85+
86+
function get_A_b_from_LinearFunction(
87+
sys::System, f::LinearFunction, p; eval_expression = false,
88+
eval_module = @__MODULE__, expression = Val{false}, u0_constructor = identity,
89+
u0_eltype = float, sparse = false)
90+
@unpack A, b, interface = f
4691
if expression == Val{true}
47-
symbolic_interface = quote
48-
update_A = $update_A
49-
update_b = $update_b
50-
sys = $sys
51-
observedfun = $observedfun
52-
$(SciMLBase.SymbolicLinearInterface)(
53-
update_A, update_b, sys, observedfun, nothing)
54-
end
5592
get_A = build_explicit_observed_function(
5693
sys, A; param_only = true, eval_expression, eval_module)
57-
if sparse
58-
get_A = SparseArrays.sparse get_A
59-
end
6094
get_b = build_explicit_observed_function(
6195
sys, b; param_only = true, eval_expression, eval_module)
62-
A = u0_constructor(get_A(p))
63-
b = u0_constructor(get_b(p))
96+
A = u0_constructor(u0_eltype.(get_A(p)))
97+
b = u0_constructor(u0_eltype.(get_b(p)))
6498
else
65-
symbolic_interface = SciMLBase.SymbolicLinearInterface(
66-
update_A, update_b, sys, observedfun, nothing)
67-
A = u0_constructor(update_A(p))
68-
b = u0_constructor(update_b(p))
99+
A = u0_constructor(u0_eltype.(interface.update_A!(p)))
100+
b = u0_constructor(u0_eltype.(interface.update_b!(p)))
101+
end
102+
if sparse
103+
A = SparseArrays.sparse(A)
69104
end
70105

71-
kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
72-
args = (; A, b, p)
73-
74-
return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
106+
return A, b
75107
end
76108

77109
# For remake

src/problems/sccnonlinearproblem.jl

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ end
1010

1111
function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
1212
exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
13-
eval_expression = false, eval_module = @__MODULE__, cse = true)
13+
eval_expression = false, eval_module = @__MODULE__, cse = true, sparse = false)
1414
ps = parameters(sys; initial_parameters = true)
1515
rps = reorder_parameters(sys, ps)
1616
obs_assigns = [eq.lhs eq.rhs for eq in obseqs]
@@ -39,9 +39,22 @@ end
3939
struct SCCNonlinearFunction{iip} end
4040

4141
function SCCNonlinearFunction{iip}(
42-
sys::System, _eqs, _dvs, _obs, cachesyms; eval_expression = false,
42+
sys::System, _eqs, _dvs, _obs, cachesyms, op; eval_expression = false,
4343
eval_module = @__MODULE__, cse = true, kwargs...) where {iip}
4444
ps = parameters(sys; initial_parameters = true)
45+
subsys = System(
46+
_eqs, _dvs, ps; observed = _obs, name = nameof(sys), defaults = defaults(sys))
47+
@set! subsys.parameter_dependencies = parameter_dependencies(sys)
48+
if get_index_cache(sys) !== nothing
49+
@set! subsys.index_cache = subset_unknowns_observed(
50+
get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
51+
@set! subsys.complete = true
52+
end
53+
# generate linear problem instead
54+
if isaffine(subsys)
55+
return LinearFunction{iip}(
56+
subsys; eval_expression, eval_module, cse, cachesyms, kwargs...)
57+
end
4558
rps = reorder_parameters(sys, ps)
4659

4760
obs_assignments = [eq.lhs eq.rhs for eq in _obs]
@@ -54,14 +67,6 @@ function SCCNonlinearFunction{iip}(
5467
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
5568
f = GeneratedFunctionWrapper{(2, 2, is_split(sys))}(f_oop, f_iip)
5669

57-
subsys = System(_eqs, _dvs, ps; observed = _obs,
58-
parameter_dependencies = parameter_dependencies(sys), name = nameof(sys))
59-
if get_index_cache(sys) !== nothing
60-
@set! subsys.index_cache = subset_unknowns_observed(
61-
get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,)))
62-
@set! subsys.complete = true
63-
end
64-
6570
return NonlinearFunction{iip}(f; sys = subsys)
6671
end
6772

@@ -70,7 +75,7 @@ function SciMLBase.SCCNonlinearProblem(sys::System, args...; kwargs...)
7075
end
7176

7277
function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = false,
73-
eval_module = @__MODULE__, cse = true, kwargs...) where {iip}
78+
eval_module = @__MODULE__, cse = true, u0_constructor = identity, kwargs...) where {iip}
7479
if !iscomplete(sys) || get_tearing_state(sys) === nothing
7580
error("A simplified `System` is required. Call `mtkcompile` on the system before creating an `SCCNonlinearProblem`.")
7681
end
@@ -113,7 +118,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
113118

114119
_, u0,
115120
p = process_SciMLProblem(
116-
EmptySciMLFunction{iip}, sys, op; eval_expression, eval_module, kwargs...)
121+
EmptySciMLFunction{iip}, sys, op; eval_expression, eval_module, symbolic_u0 = true, kwargs...)
117122

118123
explicitfuns = []
119124
nlfuns = []
@@ -224,28 +229,57 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
224229
get(cachevars, T, [])
225230
end)
226231
f = SCCNonlinearFunction{iip}(
227-
sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, cse, kwargs...)
232+
sys, _eqs, _dvs, _obs, cachebufsyms, op;
233+
eval_expression, eval_module, cse, kwargs...)
228234
push!(nlfuns, f)
229235
end
230236

237+
u0_eltype = Union{}
238+
for x in u0
239+
symbolic_type(x) == NotSymbolic() || continue
240+
u0_eltype = typeof(x)
241+
break
242+
end
243+
if u0_eltype == Union{}
244+
u0_eltype = Float64
245+
end
246+
u0_eltype = float(u0_eltype)
247+
231248
if !isempty(cachetypes)
232249
templates = map(cachetypes, cachesizes) do T, n
233250
# Real refers to `eltype(u0)`
234251
if T == Real
235-
T = eltype(u0)
252+
T = u0_eltype
236253
elseif T <: Array && eltype(T) == Real
237-
T = Array{eltype(u0), ndims(T)}
254+
T = Array{u0_eltype, ndims(T)}
238255
end
239256
BufferTemplate(T, n)
240257
end
241258
p = rebuild_with_caches(p, templates...)
242259
end
243260

261+
# yes, `get_p_constructor` since this is only used for `LinearProblem` and
262+
# will retain the shape of `A`
263+
u0_constructor = get_p_constructor(u0_constructor, typeof(u0), u0_eltype)
244264
subprobs = []
245-
for (f, vscc) in zip(nlfuns, var_sccs)
265+
for (i, (f, vscc)) in enumerate(zip(nlfuns, var_sccs))
246266
_u0 = SymbolicUtils.Code.create_array(
247267
typeof(u0), eltype(u0), Val(1), Val(length(vscc)), u0[vscc]...)
248-
prob = NonlinearProblem(f, _u0, p)
268+
symbolic_idxs = findall(x -> symbolic_type(x) != NotSymbolic(), _u0)
269+
explicitfuns[i](p, subprobs)
270+
if f isa LinearFunction
271+
_u0 = isempty(symbolic_idxs) ? _u0 : zeros(u0_eltype, length(_u0))
272+
_u0 = u0_eltype.(_u0)
273+
symbolic_interface = f.interface
274+
A,
275+
b = get_A_b_from_LinearFunction(
276+
sys, f, p; eval_expression, eval_module, u0_constructor, u0_eltype)
277+
prob = LinearProblem{iip}(A, b, p; f = symbolic_interface, u0 = _u0)
278+
else
279+
isempty(symbolic_idxs) || throw(MissingGuessError(dvs[vscc], _u0))
280+
_u0 = u0_eltype.(_u0)
281+
prob = NonlinearProblem(f, _u0, p)
282+
end
249283
push!(subprobs, prob)
250284
end
251285

@@ -255,5 +289,5 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::System, op; eval_expression = f
255289
@set! sys.eqs = new_eqs
256290
@set! sys.index_cache = subset_unknowns_observed(
257291
get_index_cache(sys), sys, new_dvs, getproperty.(obs, (:lhs,)))
258-
return SCCNonlinearProblem(subprobs, explicitfuns, p, true; sys)
292+
return SCCNonlinearProblem(Tuple(subprobs), Tuple(explicitfuns), p, true; sys)
259293
end

src/systems/codegen.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,10 +1189,11 @@ $GENERATE_X_KWARGS
11891189
All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
11901190
"""
11911191
function generate_update_A(sys::System, A::AbstractMatrix; expression = Val{true},
1192-
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
1192+
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, cachesyms = (), kwargs...)
11931193
ps = reorder_parameters(sys)
11941194

1195-
res = build_function_wrapper(sys, A, ps...; p_start = 1, expression = Val{true},
1195+
res = build_function_wrapper(
1196+
sys, A, ps..., cachesyms...; p_start = 1, expression = Val{true},
11961197
similarto = typeof(A), kwargs...)
11971198
return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res;
11981199
eval_expression, eval_module)
@@ -1211,10 +1212,11 @@ $GENERATE_X_KWARGS
12111212
All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
12121213
"""
12131214
function generate_update_b(sys::System, b::AbstractVector; expression = Val{true},
1214-
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
1215+
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, cachesyms = (), kwargs...)
12151216
ps = reorder_parameters(sys)
12161217

1217-
res = build_function_wrapper(sys, b, ps...; p_start = 1, expression = Val{true},
1218+
res = build_function_wrapper(
1219+
sys, b, ps..., cachesyms...; p_start = 1, expression = Val{true},
12181220
similarto = typeof(b), kwargs...)
12191221
return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res;
12201222
eval_expression, eval_module)

test/initial_values.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ end
202202
@variables a(t) b(t) c(t) d(t) e(t)
203203
eqs = [D(a) ~ b, D(b) ~ c, D(c) ~ d, D(d) ~ e, D(e) ~ 1]
204204
@mtkcompile sys = System(eqs, t)
205-
@test_throws ["a(t)", "c(t)"] ODEProblem(
205+
@test_throws ["d(t)", "c(t)"] ODEProblem(
206206
sys, [e => 2, a => b, b => a + 1, c => d, d => c + 1], (0, 1))
207207
end
208208

test/initializationsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1670,15 +1670,15 @@ end
16701670
x[1] ~ 0.01exp(-1)
16711671
x[2] ~ 0.01cos(t)]
16721672

1673-
@mtkbuild sys = ODESystem(eqs, t)
1673+
@mtkcompile sys = System(eqs, t)
16741674
prob = ODEProblem(sys, [], (0.0, 1.0))
16751675
sol = solve(prob, Tsit5())
16761676
@test SciMLBase.successful_retcode(sol)
16771677
end
16781678

16791679
@testset "Defaults removed with ` => nothing` aren't retained" begin
16801680
@variables x(t)[1:2]
1681-
@mtkbuild sys = System([D(x[1]) ~ -x[1], x[1] + x[2] ~ 3], t; defaults = [x[1] => 1])
1681+
@mtkcompile sys = System([D(x[1]) ~ -x[1], x[1] + x[2] ~ 3], t; defaults = [x[1] => 1])
16821682
prob = ODEProblem(sys, [x[1] => nothing, x[2] => 1], (0.0, 1.0))
16831683
@test SciMLBase.initialization_status(prob) == SciMLBase.FULLY_DETERMINED
16841684
end
@@ -1696,7 +1696,7 @@ end
16961696
D(x) ~ r * x
16971697
end
16981698
end
1699-
@mtkbuild sys = Foo(p = "a")
1699+
@mtkcompile sys = Foo(p = "a")
17001700
prob = ODEProblem(sys, [], (0.0, 1.0))
17011701
@test prob.p.nonnumeric[1] isa Vector{AbstractString}
17021702
integ = init(prob)

test/scc_nonlinear_problem.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,21 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
2727
@test_throws ["not compatible"] SCCNonlinearProblem(_model, [])
2828
model = mtkcompile(model)
2929
prob = NonlinearProblem(model, [u => zeros(8)])
30-
sccprob = SCCNonlinearProblem(model, [u => zeros(8)])
30+
sccprob = SCCNonlinearProblem(model, collect(u[1:5]) .=> zeros(5))
3131
sol1 = solve(prob, NewtonRaphson())
3232
sol2 = solve(sccprob, NewtonRaphson())
3333
@test SciMLBase.successful_retcode(sol1)
3434
@test SciMLBase.successful_retcode(sol2)
3535
@test sol1[u] sol2[u]
3636

37-
sccprob = SCCNonlinearProblem{false}(model, SA[u => zeros(8)])
37+
sccprob = SCCNonlinearProblem{false}(model, SA[(collect(u[1:5]) .=> zeros(5))...])
3838
for prob in sccprob.probs
39-
@test prob.u0 isa SVector
39+
if prob isa LinearProblem
40+
@test prob.A isa SMatrix
41+
@test prob.b isa SVector
42+
else
43+
@test prob.u0 isa SVector
44+
end
4045
@test !SciMLBase.isinplace(prob)
4146
end
4247

@@ -91,8 +96,9 @@ end
9196
@mtkcompile sys = System(eqs, [u], [p1, p2])
9297
sccprob = SCCNonlinearProblem(sys, [u => u0, p1 => p[1], p2 => p[2][]])
9398
sccsol = solve(sccprob, SimpleNewtonRaphson(); abstol = 1e-9)
99+
sccresid = prob.f(sccsol[u], (u0, p))
94100
@test SciMLBase.successful_retcode(sccsol)
95-
@test norm(sccsol.resid) < norm(sol.resid)
101+
@test norm(sccresid) < norm(sol.resid)
96102

97103
# Test BLT sorted
98104
@test istril(StructuralTransformations.sorted_incidence_matrix(sys), 1)
@@ -173,9 +179,11 @@ end
173179
0 ~ func(x[1], x[2]) * exp(x[3]) - x[4]^3 - 5
174180
0 ~ func(x[1], x[2]) * exp(x[4]) - x[3]^3 - 4])
175181
sccprob = SCCNonlinearProblem(sys, [])
182+
# since explicitfuns are called during problem construction
183+
@test val[] == 1
176184
sccsol = solve(sccprob, NewtonRaphson())
177185
@test SciMLBase.successful_retcode(sccsol)
178-
@test val[] == 1
186+
@test val[] == 2
179187
end
180188

181189
import ModelingToolkitStandardLibrary.Blocks as B

0 commit comments

Comments
 (0)