Skip to content

Commit f653c20

Browse files
feat: add symbolic interface for LinearProblem
1 parent 06cad10 commit f653c20

File tree

2 files changed

+95
-4
lines changed

2 files changed

+95
-4
lines changed

src/problems/linear_problems.jl

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,36 @@
1+
"""
2+
$(TYPEDEF)
3+
4+
A utility struct stored inside `LinearProblem` to enable a symbolic interface.
5+
6+
# Fields
7+
8+
$(TYPEDFIELDS)
9+
"""
10+
struct SymbolicLinearInterface{F1, F2, S, M}
11+
"""
12+
A function which takes `A` and the parameter object `p` and updates `A` in-place.
13+
"""
14+
update_A!::F1
15+
"""
16+
A function which takes `b` and the parameter object `p` and updates `b` in-place.
17+
"""
18+
update_b!::F2
19+
"""
20+
The symbolic backend for the `LinearProblem`.
21+
"""
22+
sys::S
23+
"""
24+
Arbitrary metadata useful for the symbolic backend.
25+
"""
26+
metadata::M
27+
end
28+
29+
__has_sys(::SymbolicLinearInterface) = true
30+
has_sys(::SymbolicLinearInterface) = true
31+
32+
SymbolicIndexingInterface.symbolic_container(sli::SymbolicLinearInterface) = sli.sys
33+
134
@doc doc"""
235
336
Defines a linear system problem.
@@ -50,20 +83,23 @@ parameters. Any extra keyword arguments are passed on to the solvers.
5083
* `b`: The right-hand side of the linear system.
5184
* `p`: The parameters for the problem. Defaults to `NullParameters`. Currently unused.
5285
* `u0`: The initial condition used by iterative solvers.
86+
* `symbolic_interface`: An instance of `SymbolicLinearInterface` if the problem was
87+
generated by a symbolic backend.
5388
* `kwargs`: The keyword arguments passed on to the solvers.
5489
"""
55-
struct LinearProblem{uType, isinplace, F, bType, P, K} <:
90+
struct LinearProblem{uType, isinplace, F, bType, P, I <: Union{SymbolicLinearInterface, Nothing}, K} <:
5691
AbstractLinearProblem{bType, isinplace}
5792
A::F
5893
b::bType
5994
u0::uType
6095
p::P
96+
f::I
6197
kwargs::K
6298
@add_kwonly function LinearProblem{iip}(A, b, p = NullParameters(); u0 = nothing,
63-
kwargs...) where {iip}
99+
f = nothing, kwargs...) where {iip}
64100
warn_paramtype(p)
65-
new{typeof(u0), iip, typeof(A), typeof(b), typeof(p), typeof(kwargs)}(A, b, u0, p,
66-
kwargs)
101+
new{typeof(u0), iip, typeof(A), typeof(b), typeof(p), typeof(f), typeof(kwargs)}(A, b, u0, p,
102+
f, kwargs)
67103
end
68104
end
69105

@@ -77,6 +113,16 @@ function LinearProblem(A, b, args...; kwargs...)
77113
end
78114
end
79115

116+
SymbolicIndexingInterface.symbolic_container(prob::LinearProblem) = prob.f
117+
SymbolicIndexingInterface.state_values(prob::LinearProblem) = prob.u0
118+
SymbolicIndexingInterface.parameter_values(prob::LinearProblem) = prob.p
119+
SymbolicIndexingInterface.is_time_dependent(::LinearProblem) = false
120+
function SymbolicIndexingInterface.set_parameter!(valp::LinearProblem{A, B, C, D, E, <:SymbolicLinearInterface}, val, idx) where {A, B, C, D, E}
121+
set_parameter!(parameter_values(valp), val, idx)
122+
valp.f.update_A!(valp.A, valp.p)
123+
valp.f.update_b!(valp.b, valp.p)
124+
end
125+
80126
@doc doc"""
81127
Holds information on what variables to alias
82128
when solving a LinearProblem. Conforms to the AbstractAliasSpecifier interface.

test/downstream/problem_interface.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,48 @@ prob = SteadyStateProblem(osys, [u0; ps])
332332
@test scc.ps[p] 2.5
333333
end
334334
end
335+
336+
@testset "LinearProblem" begin
337+
# TODO update when MTK codegen exists
338+
sys = SymbolCache([:x, :y, :z], [:p, :q, :r])
339+
update_A! = function (A, p)
340+
A[1, 1] = p[1]
341+
A[2, 2] = p[2]
342+
A[3, 3] = p[3]
343+
end
344+
update_b! = function (b, p)
345+
b[1] = p[3]
346+
b[2] = -8p[2] - p[1]
347+
end
348+
f = SciMLBase.SymbolicLinearInterface(update_A!, update_b!, sys, nothing)
349+
A = Float64[1 1 1; 6 -4 5; 5 2 2]
350+
b = Float64[2, 31, 13]
351+
p = Float64[1, -4, 2]
352+
u0 = Float64[1, 2, 3]
353+
prob = LinearProblem(A, b, p; u0, f)
354+
@test prob[:x] 1.0
355+
@test prob[:y] 2.0
356+
@test prob[:z] 3.0
357+
@test prob.ps[:p] 1.0
358+
@test prob.ps[:q] -4.0
359+
@test prob.ps[:r] 2.0
360+
prob.ps[:p] = 2.0
361+
@test prob.ps[:p] 2.0
362+
@test prob.A[1, 1] 2.0
363+
@test prob.b[2] 30.0
364+
365+
prob2 = remake(prob; u0 = 2u0)
366+
@test prob2.u0 2u0
367+
prob2 = remake(prob; p = 2p)
368+
@test prob2.p 2p
369+
prob2 = remake(prob; u0 = [:x => 3.0], p = [:q => 1.5])
370+
@test prob2.u0[1] 3.0
371+
@test prob2.p[2] 1.5
372+
373+
# no u0
374+
prob = LinearProblem(A, b, p; f)
375+
prob2 = remake(prob; p = 2p)
376+
@test prob2.p 2p
377+
prob2 = remake(prob; p = [:q => 1.5])
378+
@test prob2.p[2] 1.5
379+
end

0 commit comments

Comments
 (0)