Skip to content

Commit cc5df4c

Browse files
Merge pull request #142 from JuliaDiffEq/myb/lu
Add static version of Wfact and negate the W operator
2 parents 1cabfea + 0fd0b28 commit cc5df4c

File tree

3 files changed

+42
-8
lines changed

3 files changed

+42
-8
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,22 +126,32 @@ function generate_factorized_W(sys::ODESystem, simplify=true; version::FunctionV
126126

127127
gam = Variable(:gam; known = true)()
128128

129-
W = LinearAlgebra.I - gam*jac
129+
W = - LinearAlgebra.I + gam*jac
130130
Wfact = lu(W, Val(false), check=false).factors
131131

132132
if simplify
133133
Wfact = simplify_constants.(Wfact)
134134
end
135135

136-
W_t = LinearAlgebra.I/gam - jac
136+
W_t = - LinearAlgebra.I/gam + jac
137137
Wfact_t = lu(W_t, Val(false), check=false).factors
138138
if simplify
139139
Wfact_t = simplify_constants.(Wfact_t)
140140
end
141141

142+
if version === SArrayFunction
143+
siz = size(Wfact)
144+
constructor = :(x -> begin
145+
A = SMatrix{$siz...}(x)
146+
StaticArrays.LU(LowerTriangular( SMatrix{$siz...}(UnitLowerTriangular(A)) ), UpperTriangular(A), SVector(ntuple(n->n, max($siz...))))
147+
end)
148+
else
149+
constructor = nothing
150+
end
151+
142152
vs, ps = sys.dvs, sys.ps
143-
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys); version = version)
144-
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys); version = version)
153+
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys); version = version, constructor=constructor)
154+
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys); version = version, constructor=constructor)
145155

146156
return (Wfact_func, Wfact_t_func)
147157
end

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function flatten_expr!(x)
3131
x
3232
end
3333

34-
function build_function(rhss, vs, ps, args = (), conv = rhs -> convert(Expr, rhs); version::FunctionVersion)
34+
function build_function(rhss, vs, ps, args = (), conv = rhs -> convert(Expr, rhs); version::FunctionVersion, constructor=nothing)
3535
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(vs)]
3636
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(ps)]
3737
(ls, rs) = zip(var_pairs..., param_pairs...)
@@ -48,7 +48,7 @@ function build_function(rhss, vs, ps, args = (), conv = rhs -> convert(Expr, rhs
4848
let_expr = Expr(:let, var_eqs, sys_expr)
4949
:((u,p,$(args...)) -> begin
5050
X = $let_expr
51-
T = StaticArrays.similar_type(typeof(u), eltype(X))
51+
T = $(constructor === nothing ? :(StaticArrays.similar_type(typeof(u), eltype(X))) : constructor)
5252
T(X)
5353
end)
5454
end

test/system_construction.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit
1+
using ModelingToolkit, StaticArrays, LinearAlgebra
22
using Test
33

44
# Define some variables
@@ -35,16 +35,40 @@ generate_function(de, [x,y,z], [σ,ρ,β])
3535
generate_function(de, [x,y,z], [σ,ρ,β]; version=ModelingToolkit.SArrayFunction)
3636
jac_expr = generate_jacobian(de)
3737
jac = calculate_jacobian(de)
38+
jacfun = eval(jac_expr)
39+
# iip
3840
f = ODEFunction(de, [x,y,z], [σ,ρ,β])
3941
fw, fwt = map(eval, ModelingToolkit.generate_factorized_W(de))
4042
du = zeros(3)
4143
u = collect(1:3)
4244
p = collect(4:6)
4345
f(du, u, p, 0.1)
4446
@test du == [4, 0, -16]
47+
J = zeros(3, 3)
48+
jacfun(J, u, p, t)
4549
FW = zeros(3, 3)
50+
FWt = zeros(3, 3)
4651
fw(FW, u, p, 0.2, 0.1)
47-
fwt(FW, u, p, 0.2, 0.1)
52+
fwt(FWt, u, p, 0.2, 0.1)
53+
# oop
54+
f = ODEFunction(de, [x,y,z], [σ,ρ,β]; version=ModelingToolkit.SArrayFunction)
55+
fw, fwt = map(eval, ModelingToolkit.generate_factorized_W(de; version=ModelingToolkit.SArrayFunction))
56+
du = @SArray zeros(3)
57+
u = SVector(1:3...)
58+
p = SVector(4:6...)
59+
@test f(u, p, 0.1) === @SArray [4, 0, -16]
60+
Sfw = fw(u, p, 0.2, 0.1)
61+
@test Sfw.L UnitLowerTriangular(FW)
62+
@test Sfw.U UpperTriangular(FW)
63+
sol = Sfw \ @SArray ones(3)
64+
@test sol isa SArray
65+
@test sol -(I - 0.2*J)\ones(3)
66+
Sfw_t = fwt(u, p, 0.2, 0.1)
67+
@test Sfw_t.L UnitLowerTriangular(FWt)
68+
@test Sfw_t.U UpperTriangular(FWt)
69+
sol = Sfw_t \ @SArray ones(3)
70+
@test sol isa SArray
71+
@test sol -(I/0.2 - J)\ones(3)
4872

4973
@testset "time-varying parameters" begin
5074
@parameters σ′(t-1)

0 commit comments

Comments
 (0)