|
1 |
| -using LinearAlgebra: LinearAlgebra, qr |
2 |
| -using TensorAlgebra: TensorAlgebra, blockedperm, contract, contract!, fusedims, splitdims |
| 1 | +using LinearAlgebra: LinearAlgebra |
| 2 | +using TensorAlgebra: |
| 3 | + TensorAlgebra, blockedperm, contract, contract!, fusedims, qr, splitdims, svd |
3 | 4 | using TensorAlgebra.BaseExtensions: BaseExtensions
|
4 | 5 |
|
5 | 6 | function TensorAlgebra.contract!(
|
@@ -35,6 +36,22 @@ function Base.:*(a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray)
|
35 | 36 | return contract(a1, a2)
|
36 | 37 | end
|
37 | 38 |
|
| 39 | +# Left associative fold/reduction. |
| 40 | +# Circumvent Base definitions: |
| 41 | +# ```julia |
| 42 | +# *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix) |
| 43 | +# *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix) |
| 44 | +# ``` |
| 45 | +# that optimize matrix multiplication sequence. |
| 46 | +function Base.:*( |
| 47 | + a1::AbstractNamedDimsArray, |
| 48 | + a2::AbstractNamedDimsArray, |
| 49 | + a3::AbstractNamedDimsArray, |
| 50 | + a_rest::AbstractNamedDimsArray..., |
| 51 | +) |
| 52 | + return *(*(a1, a2), a3, a_rest...) |
| 53 | +end |
| 54 | + |
38 | 55 | function LinearAlgebra.mul!(
|
39 | 56 | a_dest::AbstractNamedDimsArray,
|
40 | 57 | a1::AbstractNamedDimsArray,
|
@@ -99,32 +116,78 @@ function TensorAlgebra.splitdims(na::AbstractNamedDimsArray, splitters::Pair...)
|
99 | 116 | return nameddims(a_split, names_split)
|
100 | 117 | end
|
101 | 118 |
|
102 |
| -function LinearAlgebra.qr( |
| 119 | +function TensorAlgebra.qr( |
103 | 120 | a::AbstractNamedDimsArray,
|
104 | 121 | nameddimsindices_codomain,
|
105 | 122 | nameddimsindices_domain;
|
106 | 123 | positive=nothing,
|
107 | 124 | )
|
108 | 125 | @assert isnothing(positive) || !positive
|
109 |
| - # TODO: This should be `TensorAlgebra.qr` rather than overloading `LinearAlgebra.qr`. |
110 |
| - # TODO: Don't require wrapping in `Tuple`. |
111 |
| - q, r = qr( |
| 126 | + q_unnamed, r_unnamed = qr( |
112 | 127 | unname(a),
|
113 |
| - Tuple(nameddimsindices(a)), |
114 |
| - Tuple(to_nameddimsindices(a, nameddimsindices_codomain)), |
115 |
| - Tuple(to_nameddimsindices(a, nameddimsindices_domain)), |
| 128 | + nameddimsindices(a), |
| 129 | + to_nameddimsindices(a, nameddimsindices_codomain), |
| 130 | + to_nameddimsindices(a, nameddimsindices_domain), |
| 131 | + ) |
| 132 | + name_q = randname(dimnames(a, 1)) |
| 133 | + name_r = name_q |
| 134 | + namedindices_q = named(last(axes(q_unnamed)), name_q) |
| 135 | + namedindices_r = named(first(axes(r_unnamed)), name_r) |
| 136 | + nameddimsindices_q = ( |
| 137 | + to_nameddimsindices(a, nameddimsindices_codomain)..., namedindices_q |
116 | 138 | )
|
117 |
| - name_qr = randname(nameddimsindices(a)[1]) |
118 |
| - nameddimsindices_q = (to_nameddimsindices(a, nameddimsindices_codomain)..., name_qr) |
119 |
| - nameddimsindices_r = (name_qr, to_nameddimsindices(a, nameddimsindices_domain)...) |
120 |
| - return nameddims(q, nameddimsindices_q), nameddims(r, nameddimsindices_r) |
| 139 | + nameddimsindices_r = (namedindices_r, to_nameddimsindices(a, nameddimsindices_domain)...) |
| 140 | + q = nameddims(q_unnamed, nameddimsindices_q) |
| 141 | + r = nameddims(r_unnamed, nameddimsindices_r) |
| 142 | + return q, r |
121 | 143 | end
|
122 | 144 |
|
123 |
| -function LinearAlgebra.qr(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...) |
| 145 | +function TensorAlgebra.qr(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...) |
124 | 146 | return qr(
|
125 | 147 | a,
|
126 | 148 | nameddimsindices_codomain,
|
127 | 149 | setdiff(nameddimsindices(a), to_nameddimsindices(a, nameddimsindices_codomain));
|
128 | 150 | kwargs...,
|
129 | 151 | )
|
130 | 152 | end
|
| 153 | + |
| 154 | +function LinearAlgebra.qr(a::AbstractNamedDimsArray, args...; kwargs...) |
| 155 | + return TensorAlgebra.qr(a, args...; kwargs...) |
| 156 | +end |
| 157 | + |
| 158 | +function TensorAlgebra.svd( |
| 159 | + a::AbstractNamedDimsArray, nameddimsindices_codomain, nameddimsindices_domain |
| 160 | +) |
| 161 | + u_unnamed, s_unnamed, v_unnamed = svd( |
| 162 | + unname(a), |
| 163 | + nameddimsindices(a), |
| 164 | + to_nameddimsindices(a, nameddimsindices_codomain), |
| 165 | + to_nameddimsindices(a, nameddimsindices_domain), |
| 166 | + ) |
| 167 | + name_u = randname(dimnames(a, 1)) |
| 168 | + name_v = randname(dimnames(a, 1)) |
| 169 | + namedindices_u = named(last(axes(u_unnamed)), name_u) |
| 170 | + namedindices_v = named(first(axes(v_unnamed)), name_v) |
| 171 | + nameddimsindices_u = ( |
| 172 | + to_nameddimsindices(a, nameddimsindices_codomain)..., namedindices_u |
| 173 | + ) |
| 174 | + nameddimsindices_s = (namedindices_u, namedindices_v) |
| 175 | + nameddimsindices_v = (namedindices_v, to_nameddimsindices(a, nameddimsindices_domain)...) |
| 176 | + u = nameddims(u_unnamed, nameddimsindices_u) |
| 177 | + s = nameddims(s_unnamed, nameddimsindices_s) |
| 178 | + v = nameddims(v_unnamed, nameddimsindices_v) |
| 179 | + return u, s, v |
| 180 | +end |
| 181 | + |
| 182 | +function TensorAlgebra.svd(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...) |
| 183 | + return svd( |
| 184 | + a, |
| 185 | + nameddimsindices_codomain, |
| 186 | + setdiff(nameddimsindices(a), to_nameddimsindices(a, nameddimsindices_codomain)); |
| 187 | + kwargs..., |
| 188 | + ) |
| 189 | +end |
| 190 | + |
| 191 | +function LinearAlgebra.svd(a::AbstractNamedDimsArray, args...; kwargs...) |
| 192 | + return TensorAlgebra.svd(a, args...; kwargs...) |
| 193 | +end |
0 commit comments