Skip to content

Commit 2b1bf2f

Browse files
committed
Add check kw in solve_for
1 parent 503c4fc commit 2b1bf2f

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

src/solve.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,34 +56,37 @@ end
5656
# return the coefficient matrix `A` and a
5757
# vector of constants (possibly symbolic) `b` such that
5858
# A \ b will solve the equations for the vars
59-
function A_b(eqs::AbstractArray, vars::AbstractArray)
59+
function A_b(eqs::AbstractArray, vars::AbstractArray, check)
6060
exprs = rhss(eqs) .- lhss(eqs)
61-
for ex in exprs
62-
@assert islinear(ex, vars)
61+
if check
62+
for ex in exprs
63+
@assert islinear(ex, vars)
64+
end
6365
end
6466
A = jacobian(exprs, vars)
6567
b = A * vars - exprs
6668
A, b
6769
end
68-
function A_b(eq, var)
70+
function A_b(eq, var, check)
6971
ex = eq.rhs - eq.lhs
70-
@assert islinear(ex, [var])
72+
check && @assert islinear(ex, [var])
7173
a = expand_derivatives(Differential(var)(ex))
7274
b = a * var - ex
7375
a, b
7476
end
7577

7678
"""
77-
solve_for(eqs::Vector, vars::Vector)
79+
solve_for(eqs::Vector, vars::Vector; simplify=true, check=true)
7880
7981
Solve the vector of equations `eqs` for a set of variables `vars`.
8082
8183
Assumes `length(eqs) == length(vars)`
8284
83-
Currently only works if all equations are linear.
85+
Currently only works if all equations are linear. `check` if the expr is linear
86+
w.r.t `vars`.
8487
"""
85-
function solve_for(eqs, vars; simplify=true)
86-
A, b = A_b(eqs, vars)
88+
function solve_for(eqs, vars; simplify=true, check=true)
89+
A, b = A_b(eqs, vars, check)
8790
#TODO: we need to make sure that `solve_for(eqs, vars)` contains no `vars`
8891
_solve(A, b, simplify)
8992
end

0 commit comments

Comments
 (0)