Skip to content

Commit 64c4833

Browse files
feat: add symbolic interface for LinearProblem
1 parent 28cb4df commit 64c4833

File tree

2 files changed

+77
-4
lines changed

2 files changed

+77
-4
lines changed

src/problems/linear_problems.jl

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,33 @@
1+
"""
2+
$(TYPEDEF)
3+
4+
A utility struct stored inside `LinearProblem` to enable a symbolic interface.
5+
6+
# Fields
7+
8+
$(TYPEDFIELDS)
9+
"""
10+
struct SymbolicLinearInterface{F1, F2, S, M}
11+
"""
12+
A function which takes `A` and the parameter object `p` and updates `A` in-place.
13+
"""
14+
update_A!::F1
15+
"""
16+
A function which takes `b` and the parameter object `p` and updates `b` in-place.
17+
"""
18+
update_b!::F2
19+
"""
20+
The symbolic backend for the `LinearProblem`.
21+
"""
22+
sys::S
23+
"""
24+
Arbitrary metadata useful for the symbolic backend.
25+
"""
26+
metadata::M
27+
end
28+
29+
SymbolicIndexingInterface.symbolic_container(sli::SymbolicLinearInterface) = sli.sys
30+
131
@doc doc"""
232
333
Defines a linear system problem.
@@ -50,20 +80,23 @@ parameters. Any extra keyword arguments are passed on to the solvers.
5080
* `b`: The right-hand side of the linear system.
5181
* `p`: The parameters for the problem. Defaults to `NullParameters`. Currently unused.
5282
* `u0`: The initial condition used by iterative solvers.
83+
* `symbolic_interface`: An instance of `SymbolicLinearInterface` if the problem was
84+
generated by a symbolic backend.
5385
* `kwargs`: The keyword arguments passed on to the solvers.
5486
"""
55-
struct LinearProblem{uType, isinplace, F, bType, P, K} <:
87+
struct LinearProblem{uType, isinplace, F, bType, P, I <: Union{SymbolicLinearInterface, Nothing}, K} <:
5688
AbstractLinearProblem{bType, isinplace}
5789
A::F
5890
b::bType
5991
u0::uType
6092
p::P
93+
symbolic_interface::I
6194
kwargs::K
6295
@add_kwonly function LinearProblem{iip}(A, b, p = NullParameters(); u0 = nothing,
63-
kwargs...) where {iip}
96+
symbolic_interface = nothing, kwargs...) where {iip}
6497
warn_paramtype(p)
65-
new{typeof(u0), iip, typeof(A), typeof(b), typeof(p), typeof(kwargs)}(A, b, u0, p,
66-
kwargs)
98+
new{typeof(u0), iip, typeof(A), typeof(b), typeof(p), typeof(symbolic_interface), typeof(kwargs)}(A, b, u0, p,
99+
symbolic_interface, kwargs)
67100
end
68101
end
69102

@@ -77,6 +110,16 @@ function LinearProblem(A, b, args...; kwargs...)
77110
end
78111
end
79112

113+
SymbolicIndexingInterface.symbolic_container(prob::LinearProblem) = prob.symbolic_interface
114+
SymbolicIndexingInterface.state_values(prob::LinearProblem) = prob.u0
115+
SymbolicIndexingInterface.parameter_values(prob::LinearProblem) = prob.p
116+
SymbolicIndexingInterface.is_time_dependent(::LinearProblem) = false
117+
function SymbolicIndexingInterface.set_parameter!(valp::LinearProblem{A, B, C, D, E, <:SymbolicLinearInterface}, val, idx) where {A, B, C, D, E}
118+
set_parameter!(parameter_values(valp), val, idx)
119+
valp.symbolic_interface.update_A!(valp.A, valp.p)
120+
valp.symbolic_interface.update_b!(valp.b, valp.p)
121+
end
122+
80123
@doc doc"""
81124
Holds information on what variables to alias
82125
when solving a LinearProblem. Conforms to the AbstractAliasSpecifier interface.

test/downstream/problem_interface.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,33 @@ prob = SteadyStateProblem(osys, u0, ps)
332332
@test scc.ps[p] 2.5
333333
end
334334
end
335+
336+
@testset "LinearProblem" begin
337+
# TODO update when MTK codegen exists
338+
sys = SymbolCache([:x, :y, :z], [:p, :q, :r])
339+
update_A! = function (A, p)
340+
A[1, 1] = p[1]
341+
A[2, 2] = p[2]
342+
A[3, 3] = p[3]
343+
end
344+
update_b! = function (b, p)
345+
b[1] = p[3]
346+
b[2] = -8p[2] - p[1]
347+
end
348+
symbolic_interface = SciMLBase.SymbolicLinearInterface(update_A!, update_b!, sys, nothing)
349+
A = Float64[1 1 1; 6 -4 5; 5 2 2]
350+
b = Float64[2, 31, 13]
351+
p = Float64[1, -4, 2]
352+
u0 = Float64[1, 2, 3]
353+
prob = LinearProblem(A, b, p; u0, symbolic_interface)
354+
@test prob[:x] 1.0
355+
@test prob[:y] 2.0
356+
@test prob[:z] 3.0
357+
@test prob.ps[:p] 1.0
358+
@test prob.ps[:q] -4.0
359+
@test prob.ps[:r] 2.0
360+
prob.ps[:p] = 2.0
361+
@test prob.ps[:p] 2.0
362+
@test prob.A[1, 1] 2.0
363+
@test prob.b[2] 30.0
364+
end

0 commit comments

Comments
 (0)