Skip to content

Commit 696bab0

Browse files
feat: add LinearProblem codegen
1 parent 3719c68 commit 696bab0

File tree

5 files changed

+224
-0
lines changed

5 files changed

+224
-0
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ include("problems/jumpproblem.jl")
188188
include("problems/initializationproblem.jl")
189189
include("problems/sccnonlinearproblem.jl")
190190
include("problems/bvproblem.jl")
191+
include("problems/linearproblem.jl")
191192

192193
include("modelingtoolkitize/common.jl")
193194
include("modelingtoolkitize/odeproblem.jl")

src/problems/compatibility.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,12 @@ function check_no_equations(sys::System, T)
169169
"""))
170170
end
171171
end
172+
173+
function check_affine(sys::System, T)
174+
if !isaffine(sys)
175+
throw(SystemCompatibilityError("""
176+
A non-affine system cannot be used to construct a `$T`. Consider a
177+
`NonlinearProblem` instead.
178+
"""))
179+
end
180+
end

src/problems/docs.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,32 @@ $PROBLEM_INTERNALS_HEADER
391391
392392
$PROBLEM_INTERNAL_KWARGS
393393
""" SciMLBase.IntervalNonlinearProblem
394+
395+
@doc """
396+
SciMLBase.LinearProblem(sys::System, op; kwargs...)
397+
SciMLBase.LinearProblem{iip}(sys::System, op; kwargs...)
398+
399+
Build a `LinearProblem` given a system `sys` and operating point `op`. `iip` is a boolean
400+
indicating whether the problem should be in-place. The operating point should be an
401+
iterable collection of key-value pairs mapping variables/parameters in the system to the
402+
(initial) values they should take in `LinearProblem`. Any values not provided will
403+
fallback to the corresponding default (if present).
404+
405+
Note that since `u0` is optional for `LinearProblem`, values of unknowns do not need to be
406+
specified in `op` to create a `LinearProblem`. In such a case, `prob.u0` will be `nothing`
407+
and attempting to symbolically index the problem with an unknown, observable, or expression
408+
depending on unknowns/observables will error.
409+
410+
Updating the parameters automatically updates the `A` and `b` arrays.
411+
412+
# Keyword arguments
413+
414+
$PROBLEM_KWARGS
415+
$(prob_fun_common_kwargs(LinearProblem, false))
416+
417+
All other keyword arguments are forwarded to the $func constructor.
418+
419+
$PROBLEM_INTERNALS_HEADER
420+
421+
$PROBLEM_INTERNAL_KWARGS
422+
""" SciMLBase.LinearProblem

src/problems/linearproblem.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
function SciMLBase.LinearProblem(sys::System, op; kwargs...)
2+
SciMLBase.LinearProblem{true}(sys, op; kwargs...)
3+
end
4+
5+
function SciMLBase.LinearProblem(sys::System, op::StaticArray; kwargs...)
6+
SciMLBase.LinearProblem{false}(sys, op; kwargs...)
7+
end
8+
9+
function SciMLBase.LinearProblem{iip}(
10+
sys::System, op; check_length = true, expression = Val{false},
11+
check_compatibility = true, sparse = false, eval_expression = false,
12+
eval_module = @__MODULE__, checkbounds = false, cse = true,
13+
u0_constructor = identity, u0_eltype = nothing, kwargs...) where {iip}
14+
check_complete(sys, LinearProblem)
15+
check_compatibility && check_compatible_system(LinearProblem, sys)
16+
17+
_, u0, p = process_SciMLProblem(
18+
EmptySciMLFunction{iip}, sys, op; check_length, expression,
19+
build_initializeprob = false, symbolic_u0 = true, u0_constructor, u0_eltype,
20+
kwargs...)
21+
22+
if any(x -> symbolic_type(x) != NotSymbolic(), u0)
23+
u0 = nothing
24+
end
25+
26+
u0Type = typeof(op)
27+
floatT = if u0 === nothing
28+
calculate_float_type(op, u0Type)
29+
else
30+
eltype(u0)
31+
end
32+
u0_eltype = something(u0_eltype, floatT)
33+
34+
u0_constructor = get_p_constructor(u0_constructor, u0Type, u0_eltype)
35+
36+
A, b = calculate_A_b(sys; sparse)
37+
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
38+
eval_module, checkbounds, cse, kwargs...)
39+
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
40+
eval_module, checkbounds, cse, kwargs...)
41+
observedfun = ObservedFunctionCache(
42+
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
43+
cse)
44+
45+
if expression == Val{true}
46+
symbolic_interface = quote
47+
update_A = $update_A
48+
update_b = $update_b
49+
sys = $sys
50+
observedfun = $observedfun
51+
$(SciMLBase.SymbolicLinearInterface)(
52+
update_A, update_b, sys, observedfun, nothing)
53+
end
54+
get_A = build_explicit_observed_function(
55+
sys, A; param_only = true, eval_expression, eval_module)
56+
if sparse
57+
get_A = SparseArrays.sparse get_A
58+
end
59+
get_b = build_explicit_observed_function(
60+
sys, b; param_only = true, eval_expression, eval_module)
61+
A = u0_constructor(get_A(p))
62+
b = u0_constructor(get_b(p))
63+
else
64+
symbolic_interface = SciMLBase.SymbolicLinearInterface(
65+
update_A, update_b, sys, observedfun, nothing)
66+
A = u0_constructor(update_A(p))
67+
b = u0_constructor(update_b(p))
68+
end
69+
70+
kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
71+
args = (; A, b, p)
72+
73+
return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
74+
end
75+
76+
# For remake
77+
function SciMLBase.get_new_A_b(
78+
sys::AbstractSystem, f::SciMLBase.SymbolicLinearInterface, p, A, b; kw...)
79+
if ArrayInterface.ismutable(A)
80+
f.update_A!(A, p)
81+
f.update_b!(b, p)
82+
else
83+
# The generated function has both IIP and OOP variants
84+
A = StaticArraysCore.similar_type(A)(f.update_A!(p))
85+
b = StaticArraysCore.similar_type(b)(f.update_b!(p))
86+
end
87+
return A, b
88+
end
89+
90+
function check_compatible_system(T::Type{LinearProblem}, sys::System)
91+
check_time_independent(sys, T)
92+
check_affine(sys, T)
93+
check_not_dde(sys)
94+
check_no_cost(sys, T)
95+
check_no_constraints(sys, T)
96+
check_no_jumps(sys, T)
97+
check_no_noise(sys, T)
98+
end

src/systems/codegen.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,3 +1130,90 @@ function build_explicit_observed_function(sys, ts;
11301130
return f
11311131
end
11321132
end
1133+
1134+
"""
1135+
$(TYPEDSIGNATURES)
1136+
1137+
Return matrix `A` and vector `b` such that the system `sys` can be represented as
1138+
`A * x = b` where `x` is `unknowns(sys)`. Errors if the system is not affine.
1139+
1140+
# Keyword arguments
1141+
1142+
- `sparse`: return a sparse `A`.
1143+
"""
1144+
function calculate_A_b(sys::System; sparse = false)
1145+
rhss = [eq.rhs for eq in full_equations(sys)]
1146+
dvs = unknowns(sys)
1147+
1148+
A = Matrix{Any}(undef, length(rhss), length(dvs))
1149+
b = Vector{Any}(undef, length(rhss))
1150+
for (i, rhs) in enumerate(rhss)
1151+
# mtkcompile makes this `0 ~ rhs` which typically ends up giving
1152+
# unknowns negative coefficients. If given the equations `A * x ~ b`
1153+
# it will simplify to `0 ~ b - A * x`. Thus this negation usually leads
1154+
# to more comprehensible user API.
1155+
resid = -rhs
1156+
for (j, var) in enumerate(dvs)
1157+
p, q, islinear = Symbolics.linear_expansion(resid, var)
1158+
if !islinear
1159+
throw(ArgumentError("System is not linear. Equation $((0 ~ rhs)) is not linear in unknown $var."))
1160+
end
1161+
A[i, j] = p
1162+
resid = q
1163+
end
1164+
# negate beucause `resid` is the residual on the LHS
1165+
b[i] = -resid
1166+
end
1167+
1168+
@assert all(Base.Fix1(isassigned, A), eachindex(A))
1169+
@assert all(Base.Fix1(isassigned, A), eachindex(b))
1170+
1171+
if sparse
1172+
A = SparseArrays.sparse(A)
1173+
end
1174+
return A, b
1175+
end
1176+
1177+
"""
1178+
$(TYPEDSIGNATURES)
1179+
1180+
Given a system `sys` and the `A` from [`calculate_A_b`](@ref) generate the function that
1181+
updates `A` given the parameter object.
1182+
1183+
# Keyword arguments
1184+
1185+
$GENERATE_X_KWARGS
1186+
1187+
All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
1188+
"""
1189+
function generate_update_A(sys::System, A::AbstractMatrix; expression = Val{true},
1190+
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
1191+
ps = reorder_parameters(sys)
1192+
1193+
res = build_function_wrapper(sys, A, ps...; p_start = 1, expression = Val{true},
1194+
similarto = typeof(A), kwargs...)
1195+
return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res;
1196+
eval_expression, eval_module)
1197+
end
1198+
1199+
"""
1200+
$(TYPEDSIGNATURES)
1201+
1202+
Given a system `sys` and the `b` from [`calculate_A_b`](@ref) generate the function that
1203+
updates `b` given the parameter object.
1204+
1205+
# Keyword arguments
1206+
1207+
$GENERATE_X_KWARGS
1208+
1209+
All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
1210+
"""
1211+
function generate_update_b(sys::System, b::AbstractVector; expression = Val{true},
1212+
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
1213+
ps = reorder_parameters(sys)
1214+
1215+
res = build_function_wrapper(sys, b, ps...; p_start = 1, expression = Val{true},
1216+
similarto = typeof(b), kwargs...)
1217+
return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res;
1218+
eval_expression, eval_module)
1219+
end

0 commit comments

Comments
 (0)