|
12 | 12 | # It should work as-is with Operation type too.
|
13 | 13 | function sym_lu(A)
|
14 | 14 | m, n = size(A)
|
15 |
| - L = fill!(Array{Any}(undef, size(A)),0) # TODO: make sparse? |
16 |
| - for i=1:min(m, n) |
17 |
| - L[i,i] = 1 |
18 |
| - end |
19 |
| - U = copy!(Array{Any}(undef, size(A)),A) |
20 |
| - p = BlasInt[1:m;] |
21 |
| - for k = 1:m-1 |
22 |
| - _, i = findmin(map(ii->_iszero(U[ii, k]) ? Inf : nterms(U[ii,k]), k:n)) |
| 15 | + F = map(x->x isa Num ? x : Num(x), A) |
| 16 | + minmn = min(m, n) |
| 17 | + p = Vector{BlasInt}(undef, minmn) |
| 18 | + info = zero(BlasInt) |
| 19 | + for k = 1:minmn |
| 20 | + val, i = findmin(map(ii->_iszero(F[ii, k]) ? Inf : nterms(F[ii,k]), k:n)) |
| 21 | + if !(val isa Symbolic) && (val isa Number) && val == Inf && iszero(info) |
| 22 | + info = k |
| 23 | + end |
23 | 24 | i += k - 1
|
24 | 25 | # swap
|
25 |
| - U[k, k:end], U[i, k:end] = U[i, k:end], U[k, k:end] |
26 |
| - L[k, 1:k-1], L[i, 1:k-1] = L[i, 1:k-1], L[k, 1:k-1] |
| 26 | + for j in 1:n |
| 27 | + F[k, j], F[i, j] = F[i, j], F[k, j] |
| 28 | + end |
27 | 29 | p[k] = i
|
28 |
| - |
29 |
| - for j = k+1:m |
30 |
| - L[j,k] = U[j, k] / U[k, k] |
31 |
| - U[j,k:m] .= U[j,k:m] .- (L[j,k],) .* U[k,k:m] |
| 30 | + for i in k+1:m |
| 31 | + F[i, k] = F[i, k] / F[k, k] |
32 | 32 | end
|
33 |
| - end |
34 |
| - for j=1:m |
35 |
| - for i=j+1:n |
36 |
| - U[i,j] = 0 |
| 33 | + for j = k+1:n |
| 34 | + for i in k+1:m |
| 35 | + F[i, j] = F[i, j] - F[i, k] * F[k, j] |
| 36 | + end |
37 | 37 | end
|
38 | 38 | end
|
39 |
| - |
40 |
| - (L, U, LinearAlgebra.ipiv2perm(p, m)) |
| 39 | + LU(F, p, info) |
41 | 40 | end
|
42 | 41 |
|
43 | 42 | # Given a vector of equations and a
|
@@ -77,46 +76,56 @@ function solve_for(eqs, vars)
|
77 | 76 | end
|
78 | 77 |
|
79 | 78 | function _solve(A::AbstractMatrix, b::AbstractArray)
|
80 |
| - A = SymbolicUtils.simplify.(to_symbolic.(A), polynorm=true) |
81 |
| - b = SymbolicUtils.simplify.(to_symbolic.(b), polynorm=true) |
82 |
| - SymbolicUtils.simplify.(ldiv(sym_lu(A), b)) |
| 79 | + A = SymbolicUtils.simplify.(Num.(A), polynorm=true) |
| 80 | + b = SymbolicUtils.simplify.(Num.(b), polynorm=true) |
| 81 | + value.(SymbolicUtils.simplify.(sym_lu(A) \ b)) |
83 | 82 | end
|
84 | 83 | _solve(a, b) = value(SymbolicUtils.simplify(b/a, polynorm=true))
|
85 | 84 |
|
86 | 85 | # ldiv below
|
87 | 86 |
|
88 | 87 | _iszero(x::Number) = iszero(x)
|
89 | 88 | _isone(x::Number) = isone(x)
|
90 |
| -_iszero(::Term) = false |
91 |
| -_isone(::Term) = false |
92 |
| - |
93 |
| -function simplifying_dot(x,y) |
94 |
| - isempty(x) && return 0 |
95 |
| - muls = map(x,y) do xi,yi |
96 |
| - _isone(xi) ? yi : _isone(yi) ? xi : _iszero(xi) ? 0 : _iszero(yi) ? 0 : xi * yi |
| 89 | +_iszero(::Symbolic) = false |
| 90 | +_isone(::Symbolic) = false |
| 91 | +_iszero(x::Num) = _iszero(value(x)) |
| 92 | +_isone(x::Num) = _isone(value(x)) |
| 93 | + |
| 94 | +LinearAlgebra.ldiv!(A::UpperTriangular{<:Union{Symbolic,Num}}, b::AbstractVector{<:Union{Symbolic,Num}}, x::AbstractVector{<:Union{Symbolic,Num}} = b) = symsub!(A, b, x) |
| 95 | +function symsub!(A::UpperTriangular, b::AbstractVector, x::AbstractVector = b) |
| 96 | + LinearAlgebra.require_one_based_indexing(A, b, x) |
| 97 | + n = size(A, 2) |
| 98 | + if !(n == length(b) == length(x)) |
| 99 | + throw(DimensionMismatch("second dimension of left hand side A, $n, length of output x, $(length(x)), and length of right hand side b, $(length(b)), must be equal")) |
97 | 100 | end
|
98 |
| - |
99 |
| - reduce(muls) do acc, x |
100 |
| - _iszero(acc) ? x : _iszero(x) ? acc : acc + x |
| 101 | + @inbounds for j in n:-1:1 |
| 102 | + _iszero(A.data[j,j]) && throw(SingularException(j)) |
| 103 | + xj = x[j] = b[j] / A.data[j,j] |
| 104 | + for i in j-1:-1:1 |
| 105 | + sub = _isone(xj) ? A.data[i,j] : A.data[i,j] * xj |
| 106 | + if !_iszero(sub) |
| 107 | + b[i] -= sub |
| 108 | + end |
| 109 | + end |
101 | 110 | end
|
| 111 | + x |
102 | 112 | end
|
103 | 113 |
|
104 |
| -function ldiv((L,U,p), b) |
105 |
| - m, n = size(L) |
106 |
| - x = Vector{Any}(undef, length(b)) |
107 |
| - b = b[p] |
108 |
| - |
109 |
| - for i=n:-1:1 |
110 |
| - sub = simplifying_dot(x[i+1:end], U[i,i+1:end]) |
111 |
| - den = U[i,i] |
112 |
| - x[i] = _iszero(sub) ? b[i] : b[i] - sub |
113 |
| - x[i] = _isone(den) ? x[i] : _isone(-den) ? -x[i] : x[i] / den |
| 114 | +LinearAlgebra.ldiv!(A::UnitLowerTriangular{<:Union{Symbolic,Num}}, b::AbstractVector{<:Union{Symbolic,Num}}, x::AbstractVector{<:Union{Symbolic,Num}} = b) = symsub!(A, b, x) |
| 115 | +function symsub!(A::UnitLowerTriangular, b::AbstractVector, x::AbstractVector = b) |
| 116 | + LinearAlgebra.require_one_based_indexing(A, b, x) |
| 117 | + n = size(A, 2) |
| 118 | + if !(n == length(b) == length(x)) |
| 119 | + throw(DimensionMismatch("second dimension of left hand side A, $n, length of output x, $(length(x)), and length of right hand side b, $(length(b)), must be equal")) |
114 | 120 | end
|
115 |
| - |
116 |
| - # unit lower triangular solve first: |
117 |
| - for i=1:n |
118 |
| - sub = simplifying_dot(b[1:i-1], L[i, 1:i-1]) # this should be `b` not x |
119 |
| - x[i] = _iszero(sub) ? x[i] : x[i] - sub |
| 121 | + @inbounds for j in 1:n |
| 122 | + xj = x[j] = b[j] |
| 123 | + for i in j+1:n |
| 124 | + sub = _isone(xj) ? A.data[i,j] : A.data[i,j] * xj |
| 125 | + if !_iszero(sub) |
| 126 | + b[i] -= sub |
| 127 | + end |
| 128 | + end |
120 | 129 | end
|
121 | 130 | x
|
122 | 131 | end
|
0 commit comments