1
1
using LinearAlgebra: LinearAlgebra
2
2
using TensorAlgebra:
3
- TensorAlgebra, blockedperm, contract, contract!, fusedims, permmortar, qr, splitdims, svd
3
+ TensorAlgebra,
4
+ blockedperm,
5
+ contract,
6
+ contract!,
7
+ eigen,
8
+ eigvals,
9
+ fusedims,
10
+ left_null,
11
+ lq,
12
+ permmortar,
13
+ qr,
14
+ right_null,
15
+ splitdims,
16
+ svd,
17
+ svdvals
4
18
using TensorAlgebra. BaseExtensions: BaseExtensions
5
19
6
20
function TensorAlgebra. contract! (
@@ -94,7 +108,7 @@ function TensorAlgebra.fusedims(na::AbstractNamedDimsArray, fusions::Pair...)
94
108
)
95
109
end
96
110
perm = blockedperm (na, nameddimsindices_fuse... )
97
- a_fused = fusedims (unname (na), perm)
111
+ a_fused = fusedims (dename (na), perm)
98
112
return nameddimsarray (a_fused, nameddimsindices_fused)
99
113
end
100
114
@@ -107,7 +121,7 @@ function TensorAlgebra.splitdims(na::AbstractNamedDimsArray, splitters::Pair...)
107
121
split_lengths = unname .(split_namedlengths)
108
122
return fused_dim => split_lengths
109
123
end
110
- a_split = splitdims (unname (na), splitters_unnamed... )
124
+ a_split = splitdims (dename (na), splitters_unnamed... )
111
125
names_split = Any[tuple .(nameddimsindices (na))... ]
112
126
for splitter in splitters
113
127
fused_name, split_namedlengths = splitter
@@ -120,77 +134,170 @@ function TensorAlgebra.splitdims(na::AbstractNamedDimsArray, splitters::Pair...)
120
134
end
121
135
122
136
function TensorAlgebra. qr (
123
- a:: AbstractNamedDimsArray ,
124
- nameddimsindices_codomain,
125
- nameddimsindices_domain;
126
- positive= nothing ,
137
+ a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs...
127
138
)
128
- @assert isnothing (positive) || ! positive
129
- q_unnamed, r_unnamed = qr (
130
- unname (a),
131
- nameddimsindices (a),
132
- to_nameddimsindices (a, nameddimsindices_codomain),
133
- to_nameddimsindices (a, nameddimsindices_domain),
134
- )
139
+ codomain = to_nameddimsindices (a, dimnames_codomain)
140
+ domain = to_nameddimsindices (a, dimnames_domain)
141
+ q_unnamed, r_unnamed = qr (dename (a), nameddimsindices (a), codomain, domain; kwargs... )
135
142
name_q = randname (dimnames (a, 1 ))
136
143
name_r = name_q
137
144
namedindices_q = named (last (axes (q_unnamed)), name_q)
138
145
namedindices_r = named (first (axes (r_unnamed)), name_r)
139
- nameddimsindices_q = (
140
- to_nameddimsindices (a, nameddimsindices_codomain)... , namedindices_q
141
- )
142
- nameddimsindices_r = (namedindices_r, to_nameddimsindices (a, nameddimsindices_domain)... )
146
+ nameddimsindices_q = (codomain... , namedindices_q)
147
+ nameddimsindices_r = (namedindices_r, domain... )
143
148
q = nameddimsarray (q_unnamed, nameddimsindices_q)
144
149
r = nameddimsarray (r_unnamed, nameddimsindices_r)
145
150
return q, r
146
151
end
147
-
148
- function TensorAlgebra. qr (a:: AbstractNamedDimsArray , nameddimsindices_codomain; kwargs... )
149
- return qr (
150
- a,
151
- nameddimsindices_codomain,
152
- setdiff (nameddimsindices (a), to_nameddimsindices (a, nameddimsindices_codomain));
153
- kwargs... ,
154
- )
152
+ function TensorAlgebra. qr (a:: AbstractNamedDimsArray , dimnames_codomain; kwargs... )
153
+ codomain = to_nameddimsindices (a, dimnames_codomain)
154
+ domain = setdiff (nameddimsindices (a), codomain)
155
+ return qr (a, codomain, domain; kwargs... )
155
156
end
156
-
157
157
function LinearAlgebra. qr (a:: AbstractNamedDimsArray , args... ; kwargs... )
158
158
return TensorAlgebra. qr (a, args... ; kwargs... )
159
159
end
160
160
161
+ function TensorAlgebra. lq (
162
+ a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs...
163
+ )
164
+ codomain = to_nameddimsindices (a, dimnames_codomain)
165
+ domain = to_nameddimsindices (a, dimnames_domain)
166
+ l_unnamed, q_unnamed = lq (dename (a), nameddimsindices (a), codomain, domain; kwargs... )
167
+ name_l = randname (dimnames (a, 1 ))
168
+ name_q = name_l
169
+ namedindices_l = named (last (axes (l_unnamed)), name_l)
170
+ namedindices_q = named (first (axes (q_unnamed)), name_q)
171
+ nameddimsindices_l = (codomain... , namedindices_l)
172
+ nameddimsindices_q = (namedindices_q, domain... )
173
+ l = nameddimsarray (l_unnamed, nameddimsindices_l)
174
+ q = nameddimsarray (q_unnamed, nameddimsindices_q)
175
+ return l, q
176
+ end
177
+ function TensorAlgebra. lq (a:: AbstractNamedDimsArray , dimnames_codomain; kwargs... )
178
+ codomain = to_nameddimsindices (a, dimnames_codomain)
179
+ domain = setdiff (nameddimsindices (a), codomain)
180
+ return lq (a, codomain, domain; kwargs... )
181
+ end
182
+ function LinearAlgebra. lq (a:: AbstractNamedDimsArray , args... ; kwargs... )
183
+ return TensorAlgebra. lq (a, args... ; kwargs... )
184
+ end
185
+
161
186
function TensorAlgebra. svd (
162
- a:: AbstractNamedDimsArray , nameddimsindices_codomain, nameddimsindices_domain
187
+ a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs ...
163
188
)
189
+ codomain = to_nameddimsindices (a, dimnames_codomain)
190
+ domain = to_nameddimsindices (a, dimnames_domain)
164
191
u_unnamed, s_unnamed, v_unnamed = svd (
165
- unname (a),
166
- nameddimsindices (a),
167
- to_nameddimsindices (a, nameddimsindices_codomain),
168
- to_nameddimsindices (a, nameddimsindices_domain),
192
+ dename (a), nameddimsindices (a), codomain, domain; kwargs...
169
193
)
170
194
name_u = randname (dimnames (a, 1 ))
171
195
name_v = randname (dimnames (a, 1 ))
172
196
namedindices_u = named (last (axes (u_unnamed)), name_u)
173
197
namedindices_v = named (first (axes (v_unnamed)), name_v)
174
- nameddimsindices_u = (
175
- to_nameddimsindices (a, nameddimsindices_codomain)... , namedindices_u
176
- )
198
+ nameddimsindices_u = (codomain... , namedindices_u)
177
199
nameddimsindices_s = (namedindices_u, namedindices_v)
178
- nameddimsindices_v = (namedindices_v, to_nameddimsindices (a, nameddimsindices_domain) ... )
200
+ nameddimsindices_v = (namedindices_v, domain ... )
179
201
u = nameddimsarray (u_unnamed, nameddimsindices_u)
180
202
s = nameddimsarray (s_unnamed, nameddimsindices_s)
181
203
v = nameddimsarray (v_unnamed, nameddimsindices_v)
182
204
return u, s, v
183
205
end
184
-
185
- function TensorAlgebra. svd (a:: AbstractNamedDimsArray , nameddimsindices_codomain; kwargs... )
206
+ function TensorAlgebra. svd (a:: AbstractNamedDimsArray , dimnames_codomain; kwargs... )
186
207
return svd (
187
208
a,
188
- nameddimsindices_codomain ,
189
- setdiff (nameddimsindices (a), to_nameddimsindices (a, nameddimsindices_codomain ));
209
+ dimnames_codomain ,
210
+ setdiff (nameddimsindices (a), to_nameddimsindices (a, dimnames_codomain ));
190
211
kwargs... ,
191
212
)
192
213
end
193
-
194
214
function LinearAlgebra. svd (a:: AbstractNamedDimsArray , args... ; kwargs... )
195
215
return TensorAlgebra. svd (a, args... ; kwargs... )
196
216
end
217
+
218
+ function TensorAlgebra. svdvals (
219
+ a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs...
220
+ )
221
+ return svdvals (
222
+ dename (a),
223
+ nameddimsindices (a),
224
+ to_nameddimsindices (a, dimnames_codomain),
225
+ to_nameddimsindices (a, dimnames_domain);
226
+ kwargs... ,
227
+ )
228
+ end
229
+ function TensorAlgebra. svdvals (a:: AbstractNamedDimsArray , dimnames_codomain; kwargs... )
230
+ codomain = to_nameddimsindices (a, dimnames_codomain)
231
+ domain = setdiff (nameddimsindices (a), codomain)
232
+ return svdvals (a, codomain, domain; kwargs... )
233
+ end
234
+ function LinearAlgebra. svdvals (a:: AbstractNamedDimsArray , args... ; kwargs... )
235
+ return TensorAlgebra. svdvals (a, args... ; kwargs... )
236
+ end
237
+
238
+ function TensorAlgebra. eigen (
239
+ a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs...
240
+ )
241
+ codomain = to_nameddimsindices (a, dimnames_codomain)
242
+ domain = to_nameddimsindices (a, dimnames_domain)
243
+ d_unnamed, v_unnamed = eigen (dename (a), nameddimsindices (a), codomain, domain; kwargs... )
244
+ name_d = randname (dimnames (a, 1 ))
245
+ name_d′ = randname (name_d)
246
+ name_v = name_d
247
+ namedindices_d = named (last (axes (d_unnamed)), name_d)
248
+ namedindices_d′ = named (first (axes (d_unnamed)), name_d′)
249
+ namedindices_v = named (last (axes (v_unnamed)), name_v)
250
+ nameddimsindices_d = (namedindices_d′, namedindices_d)
251
+ nameddimsindices_v = (domain... , namedindices_v)
252
+ d = nameddimsarray (d_unnamed, nameddimsindices_d)
253
+ v = nameddimsarray (v_unnamed, nameddimsindices_v)
254
+ return d, v
255
+ end
256
+ function LinearAlgebra. eigen (a:: AbstractNamedDimsArray , args... ; kwargs... )
257
+ return TensorAlgebra. eigen (a, args... ; kwargs... )
258
+ end
259
+
260
+ function TensorAlgebra. eigvals (
261
+ a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs...
262
+ )
263
+ codomain = to_nameddimsindices (a, dimnames_codomain)
264
+ domain = to_nameddimsindices (a, dimnames_domain)
265
+ return eigvals (dename (a), nameddimsindices (a), codomain, domain; kwargs... )
266
+ end
267
+ function LinearAlgebra. eigvals (a:: AbstractNamedDimsArray , args... ; kwargs... )
268
+ return TensorAlgebra. eigvals (a, args... ; kwargs... )
269
+ end
270
+
271
+ function TensorAlgebra. left_null (
272
+ a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs...
273
+ )
274
+ codomain = to_nameddimsindices (a, dimnames_codomain)
275
+ domain = to_nameddimsindices (a, dimnames_domain)
276
+ n_unnamed = left_null (dename (a), nameddimsindices (a), codomain, domain; kwargs... )
277
+ name_n = randname (dimnames (a, 1 ))
278
+ namedindices_n = named (last (axes (n_unnamed)), name_n)
279
+ nameddimsindices_n = (codomain... , namedindices_n)
280
+ return nameddimsarray (n_unnamed, nameddimsindices_n)
281
+ end
282
+ function TensorAlgebra. left_null (a:: AbstractNamedDimsArray , dimnames_codomain; kwargs... )
283
+ codomain = to_nameddimsindices (a, dimnames_codomain)
284
+ domain = setdiff (nameddimsindices (a), codomain)
285
+ return left_null (a, codomain, domain; kwargs... )
286
+ end
287
+
288
+ function TensorAlgebra. right_null (
289
+ a:: AbstractNamedDimsArray , dimnames_codomain, dimnames_domain; kwargs...
290
+ )
291
+ codomain = to_nameddimsindices (a, dimnames_codomain)
292
+ domain = to_nameddimsindices (a, dimnames_domain)
293
+ n_unnamed = right_null (dename (a), nameddimsindices (a), codomain, domain; kwargs... )
294
+ name_n = randname (dimnames (a, 1 ))
295
+ namedindices_n = named (first (axes (n_unnamed)), name_n)
296
+ nameddimsindices_n = (namedindices_n, domain... )
297
+ return nameddimsarray (n_unnamed, nameddimsindices_n)
298
+ end
299
+ function TensorAlgebra. right_null (a:: AbstractNamedDimsArray , dimnames_codomain; kwargs... )
300
+ codomain = to_nameddimsindices (a, dimnames_codomain)
301
+ domain = setdiff (nameddimsindices (a), codomain)
302
+ return right_null (a, codomain, domain; kwargs... )
303
+ end
0 commit comments