Skip to content

Commit 5965fec

Browse files
committed
make \ use solve
1 parent e3e4335 commit 5965fec

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

src/solve.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function sym_lu(A)
1616
for i=1:min(m, n)
1717
L[i,i] = 1
1818
end
19-
U = copy(A)
19+
U = copy!(Array{Any}(undef, size(A)),A)
2020
p = BlasInt[1:m;]
2121
for k = 1:m-1
2222
_, i = findmin(map(ii->iszero(U[ii, k]) ? Inf : nterms(U[ii,k]), k:n))
@@ -31,14 +31,13 @@ function sym_lu(A)
3131
U[j,k:m] .= U[j,k:m] .- L[j,k] .* U[k,k:m]
3232
end
3333
end
34-
factors = copy(U)
3534
for j=1:m
3635
for i=j+1:n
37-
factors[i,j] = L[i,j]
36+
U[i,j] = 0
3837
end
3938
end
4039

41-
LU(factors, p, BlasInt(0))
40+
(L, U, LinearAlgebra.ipiv2perm(p, m))
4241
end
4342

4443
# Given a vector of equations and a
@@ -67,11 +66,19 @@ Currently only works if all equations are linear.
6766
"""
6867
function solve(eqs, vars)
6968
A, b = A_b(eqs, vars)
69+
_solve(A, b)
70+
end
71+
72+
function _solve(A, b)
7073
A = SymbolicUtils.simplify.(to_symbolic.(A), polynorm=true)
7174
b = SymbolicUtils.simplify.(to_symbolic.(b), polynorm=true)
7275
map(to_mtk, SymbolicUtils.simplify.(ldiv(sym_lu(A), b)))
7376
end
7477

78+
LinearAlgebra.:(\)(A::AbstractMatrix{<:Expression}, b::AbstractVector{<:Expression}) = _solve(A, b)
79+
LinearAlgebra.:(\)(A::AbstractMatrix{<:Expression}, b::AbstractVector) = _solve(A, b)
80+
LinearAlgebra.:(\)(A::AbstractMatrix, b::AbstractVector{<:Expression}) = _solve(A, b)
81+
7582
# ldiv below
7683

7784
_iszero(x::Number) = iszero(x)
@@ -90,13 +97,10 @@ function simplifying_dot(x,y)
9097
end
9198
end
9299

93-
function ldiv(A::LU, b)
94-
L = A.L
95-
U = A.U
96-
100+
function ldiv((L,U,p), b)
97101
m, n = size(L)
98102
x = Vector{Any}(undef, length(b))
99-
b = b[A.p]
103+
b = b[p]
100104

101105
for i=n:-1:1
102106
sub = simplifying_dot(x[i+1:end], U[i,i+1:end])

test/operation_overloads.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,8 @@ M \ b
7070
M \ reshape(b,2,1)
7171
M = [1 1; 0 2]
7272
M \ reshape(b,2,1)
73+
74+
75+
M = [1 a; 0 2]
76+
M \ b
77+
M \ [1, 2]

0 commit comments

Comments
 (0)