@@ -6,12 +6,17 @@ using TensorAlgebra:
6
6
contract!,
7
7
eigen,
8
8
eigvals,
9
+ factorize,
9
10
fusedims,
10
11
left_null,
12
+ left_orth,
13
+ left_polar,
11
14
lq,
12
15
permmortar,
13
16
qr,
14
17
right_null,
18
+ right_orth,
19
+ right_polar,
15
20
splitdims,
16
21
svd,
17
22
svdvals
@@ -133,55 +138,57 @@ function TensorAlgebra.splitdims(na::AbstractNamedDimsArray, splitters::Pair...)
133
138
return nameddimsarray (a_split, names_split)
134
139
end
135
140
136
- function TensorAlgebra. qr (
137
- a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs...
141
+ # Generic interface for forwarding binary factorizations
142
+ # to the corresponding functions in TensorAlgebra.jl.
143
+ function factorize_with (
144
+ f, a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs...
138
145
)
139
146
codomain = to_nameddimsindices (a, dimnames_codomain)
140
147
domain = to_nameddimsindices (a, dimnames_domain)
141
- q_unnamed, r_unnamed = qr (dename (a), nameddimsindices (a), codomain, domain; kwargs... )
142
- name_q = randname (dimnames (a, 1 ))
143
- name_r = name_q
144
- namedindices_q = named (last (axes (q_unnamed )), name_q )
145
- namedindices_r = named (first (axes (r_unnamed )), name_r )
146
- nameddimsindices_q = (codomain... , namedindices_q )
147
- nameddimsindices_r = (namedindices_r , domain... )
148
- q = nameddimsarray (q_unnamed, nameddimsindices_q )
149
- r = nameddimsarray (r_unnamed, nameddimsindices_r )
150
- return q, r
148
+ x_unnamed, y_unnamed = f (dename (a), nameddimsindices (a), codomain, domain; kwargs... )
149
+ name_x = randname (dimnames (a, 1 ))
150
+ name_y = name_x
151
+ namedindices_x = named (last (axes (x_unnamed )), name_x )
152
+ namedindices_y = named (first (axes (y_unnamed )), name_y )
153
+ nameddimsindices_x = (codomain... , namedindices_x )
154
+ nameddimsindices_y = (namedindices_y , domain... )
155
+ x = nameddimsarray (x_unnamed, nameddimsindices_x )
156
+ y = nameddimsarray (y_unnamed, nameddimsindices_y )
157
+ return x, y
151
158
end
152
- function TensorAlgebra . qr ( a:: AbstractNamedDimsArray , dimnames_codomain; kwargs... )
159
+ function factorize_with (f, a:: AbstractNamedDimsArray , dimnames_codomain; kwargs... )
153
160
codomain = to_nameddimsindices (a, dimnames_codomain)
154
161
domain = setdiff (nameddimsindices (a), codomain)
155
- return qr (a, codomain, domain; kwargs... )
156
- end
157
- function LinearAlgebra. qr (a:: AbstractNamedDimsArray , args... ; kwargs... )
158
- return TensorAlgebra. qr (a, args... ; kwargs... )
162
+ return factorize_with (f, a, codomain, domain; kwargs... )
159
163
end
160
164
161
- function TensorAlgebra. lq (
162
- a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs...
163
- )
164
- codomain = to_nameddimsindices (a, dimnames_codomain)
165
- domain = to_nameddimsindices (a, dimnames_domain)
166
- l_unnamed, q_unnamed = lq (dename (a), nameddimsindices (a), codomain, domain; kwargs... )
167
- name_l = randname (dimnames (a, 1 ))
168
- name_q = name_l
169
- namedindices_l = named (last (axes (l_unnamed)), name_l)
170
- namedindices_q = named (first (axes (q_unnamed)), name_q)
171
- nameddimsindices_l = (codomain... , namedindices_l)
172
- nameddimsindices_q = (namedindices_q, domain... )
173
- l = nameddimsarray (l_unnamed, nameddimsindices_l)
174
- q = nameddimsarray (q_unnamed, nameddimsindices_q)
175
- return l, q
165
+ for f in [:qr , :lq , :left_polar , :right_polar , :left_orth , :right_orth , :factorize ]
166
+ @eval begin
167
+ function TensorAlgebra. $f (
168
+ a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs...
169
+ )
170
+ return factorize_with ($ f, a, dimnames_codomain, dimnames_domain; kwargs... )
171
+ end
172
+ function TensorAlgebra. $f (a:: AbstractNamedDimsArray , dimnames_codomain; kwargs... )
173
+ return factorize_with ($ f, a, dimnames_codomain; kwargs... )
174
+ end
175
+ end
176
176
end
177
- function TensorAlgebra . lq (a :: AbstractNamedDimsArray , dimnames_codomain; kwargs ... )
178
- codomain = to_nameddimsindices (a, dimnames_codomain)
179
- domain = setdiff ( nameddimsindices (a), codomain )
180
- return lq (a, codomain, domain ; kwargs... )
177
+
178
+ # Overload LinearAlgebra functions where relevant.
179
+ function LinearAlgebra . qr (a :: AbstractNamedDimsArray , args ... ; kwargs ... )
180
+ return TensorAlgebra . qr (a, args ... ; kwargs... )
181
181
end
182
182
function LinearAlgebra. lq (a:: AbstractNamedDimsArray , args... ; kwargs... )
183
183
return TensorAlgebra. lq (a, args... ; kwargs... )
184
184
end
185
+ function LinearAlgebra. factorize (a:: AbstractNamedDimsArray , args... ; kwargs... )
186
+ return TensorAlgebra. factorize (a, args... ; kwargs... )
187
+ end
188
+
189
+ #
190
+ # Non-binary factorizations.
191
+ #
185
192
186
193
function TensorAlgebra. svd (
187
194
a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs...
0 commit comments