Skip to content

Commit 5a1345e

Browse files
Revert "refactor: use Symbolics.semilinear_form for LinearProblem codegen"
This reverts commit 2aba391.
1 parent 393af60 commit 5a1345e

File tree

1 file changed

+26
-6
lines changed

1 file changed

+26
-6
lines changed

src/systems/codegen.jl

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,15 +1150,35 @@ Return matrix `A` and vector `b` such that the system `sys` can be represented a
11501150
- `sparse`: return a sparse `A`.
11511151
"""
11521152
function calculate_A_b(sys::System; sparse = false)
1153-
rhss = [-eq.rhs for eq in full_equations(sys)]
1153+
rhss = [eq.rhs for eq in full_equations(sys)]
11541154
dvs = unknowns(sys)
11551155

1156-
A, b = semilinear_form(rhss, dvs)
1157-
if !sparse
1158-
A = collect(A)
1156+
A = Matrix{Any}(undef, length(rhss), length(dvs))
1157+
b = Vector{Any}(undef, length(rhss))
1158+
for (i, rhs) in enumerate(rhss)
1159+
# mtkcompile makes this `0 ~ rhs` which typically ends up giving
1160+
# unknowns negative coefficients. If given the equations `A * x ~ b`
1161+
# it will simplify to `0 ~ b - A * x`. Thus this negation usually leads
1162+
# to more comprehensible user API.
1163+
resid = -rhs
1164+
for (j, var) in enumerate(dvs)
1165+
p, q, islinear = Symbolics.linear_expansion(resid, var)
1166+
if !islinear
1167+
throw(ArgumentError("System is not linear. Equation $((0 ~ rhs)) is not linear in unknown $var."))
1168+
end
1169+
A[i, j] = p
1170+
resid = q
1171+
end
1172+
# negate beucause `resid` is the residual on the LHS
1173+
b[i] = -resid
1174+
end
1175+
1176+
@assert all(Base.Fix1(isassigned, A), eachindex(A))
1177+
@assert all(Base.Fix1(isassigned, A), eachindex(b))
1178+
1179+
if sparse
1180+
A = SparseArrays.sparse(A)
11591181
end
1160-
A = unwrap.(A)
1161-
b = unwrap.(-b)
11621182
return A, b
11631183
end
11641184

0 commit comments

Comments
 (0)