Skip to content

Commit d2d7be9

Browse files
authored
Add more factorizations (#72)
1 parent 3249398 commit d2d7be9

File tree

3 files changed

+181
-49
lines changed

3 files changed

+181
-49
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.6.0"
4+
version = "0.6.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/tensoralgebra.jl

Lines changed: 148 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
11
using LinearAlgebra: LinearAlgebra
22
using TensorAlgebra:
3-
TensorAlgebra, blockedperm, contract, contract!, fusedims, permmortar, qr, splitdims, svd
3+
TensorAlgebra,
4+
blockedperm,
5+
contract,
6+
contract!,
7+
eigen,
8+
eigvals,
9+
fusedims,
10+
left_null,
11+
lq,
12+
permmortar,
13+
qr,
14+
right_null,
15+
splitdims,
16+
svd,
17+
svdvals
418
using TensorAlgebra.BaseExtensions: BaseExtensions
519

620
function TensorAlgebra.contract!(
@@ -94,7 +108,7 @@ function TensorAlgebra.fusedims(na::AbstractNamedDimsArray, fusions::Pair...)
94108
)
95109
end
96110
perm = blockedperm(na, nameddimsindices_fuse...)
97-
a_fused = fusedims(unname(na), perm)
111+
a_fused = fusedims(dename(na), perm)
98112
return nameddimsarray(a_fused, nameddimsindices_fused)
99113
end
100114

@@ -107,7 +121,7 @@ function TensorAlgebra.splitdims(na::AbstractNamedDimsArray, splitters::Pair...)
107121
split_lengths = unname.(split_namedlengths)
108122
return fused_dim => split_lengths
109123
end
110-
a_split = splitdims(unname(na), splitters_unnamed...)
124+
a_split = splitdims(dename(na), splitters_unnamed...)
111125
names_split = Any[tuple.(nameddimsindices(na))...]
112126
for splitter in splitters
113127
fused_name, split_namedlengths = splitter
@@ -120,77 +134,170 @@ function TensorAlgebra.splitdims(na::AbstractNamedDimsArray, splitters::Pair...)
120134
end
121135

122136
function TensorAlgebra.qr(
123-
a::AbstractNamedDimsArray,
124-
nameddimsindices_codomain,
125-
nameddimsindices_domain;
126-
positive=nothing,
137+
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
127138
)
128-
@assert isnothing(positive) || !positive
129-
q_unnamed, r_unnamed = qr(
130-
unname(a),
131-
nameddimsindices(a),
132-
to_nameddimsindices(a, nameddimsindices_codomain),
133-
to_nameddimsindices(a, nameddimsindices_domain),
134-
)
139+
codomain = to_nameddimsindices(a, dimnames_codomain)
140+
domain = to_nameddimsindices(a, dimnames_domain)
141+
q_unnamed, r_unnamed = qr(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
135142
name_q = randname(dimnames(a, 1))
136143
name_r = name_q
137144
namedindices_q = named(last(axes(q_unnamed)), name_q)
138145
namedindices_r = named(first(axes(r_unnamed)), name_r)
139-
nameddimsindices_q = (
140-
to_nameddimsindices(a, nameddimsindices_codomain)..., namedindices_q
141-
)
142-
nameddimsindices_r = (namedindices_r, to_nameddimsindices(a, nameddimsindices_domain)...)
146+
nameddimsindices_q = (codomain..., namedindices_q)
147+
nameddimsindices_r = (namedindices_r, domain...)
143148
q = nameddimsarray(q_unnamed, nameddimsindices_q)
144149
r = nameddimsarray(r_unnamed, nameddimsindices_r)
145150
return q, r
146151
end
147-
148-
function TensorAlgebra.qr(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...)
149-
return qr(
150-
a,
151-
nameddimsindices_codomain,
152-
setdiff(nameddimsindices(a), to_nameddimsindices(a, nameddimsindices_codomain));
153-
kwargs...,
154-
)
152+
function TensorAlgebra.qr(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
153+
codomain = to_nameddimsindices(a, dimnames_codomain)
154+
domain = setdiff(nameddimsindices(a), codomain)
155+
return qr(a, codomain, domain; kwargs...)
155156
end
156-
157157
function LinearAlgebra.qr(a::AbstractNamedDimsArray, args...; kwargs...)
158158
return TensorAlgebra.qr(a, args...; kwargs...)
159159
end
160160

161+
function TensorAlgebra.lq(
162+
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
163+
)
164+
codomain = to_nameddimsindices(a, dimnames_codomain)
165+
domain = to_nameddimsindices(a, dimnames_domain)
166+
l_unnamed, q_unnamed = lq(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
167+
name_l = randname(dimnames(a, 1))
168+
name_q = name_l
169+
namedindices_l = named(last(axes(l_unnamed)), name_l)
170+
namedindices_q = named(first(axes(q_unnamed)), name_q)
171+
nameddimsindices_l = (codomain..., namedindices_l)
172+
nameddimsindices_q = (namedindices_q, domain...)
173+
l = nameddimsarray(l_unnamed, nameddimsindices_l)
174+
q = nameddimsarray(q_unnamed, nameddimsindices_q)
175+
return l, q
176+
end
177+
function TensorAlgebra.lq(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
178+
codomain = to_nameddimsindices(a, dimnames_codomain)
179+
domain = setdiff(nameddimsindices(a), codomain)
180+
return lq(a, codomain, domain; kwargs...)
181+
end
182+
function LinearAlgebra.lq(a::AbstractNamedDimsArray, args...; kwargs...)
183+
return TensorAlgebra.lq(a, args...; kwargs...)
184+
end
185+
161186
function TensorAlgebra.svd(
162-
a::AbstractNamedDimsArray, nameddimsindices_codomain, nameddimsindices_domain
187+
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
163188
)
189+
codomain = to_nameddimsindices(a, dimnames_codomain)
190+
domain = to_nameddimsindices(a, dimnames_domain)
164191
u_unnamed, s_unnamed, v_unnamed = svd(
165-
unname(a),
166-
nameddimsindices(a),
167-
to_nameddimsindices(a, nameddimsindices_codomain),
168-
to_nameddimsindices(a, nameddimsindices_domain),
192+
dename(a), nameddimsindices(a), codomain, domain; kwargs...
169193
)
170194
name_u = randname(dimnames(a, 1))
171195
name_v = randname(dimnames(a, 1))
172196
namedindices_u = named(last(axes(u_unnamed)), name_u)
173197
namedindices_v = named(first(axes(v_unnamed)), name_v)
174-
nameddimsindices_u = (
175-
to_nameddimsindices(a, nameddimsindices_codomain)..., namedindices_u
176-
)
198+
nameddimsindices_u = (codomain..., namedindices_u)
177199
nameddimsindices_s = (namedindices_u, namedindices_v)
178-
nameddimsindices_v = (namedindices_v, to_nameddimsindices(a, nameddimsindices_domain)...)
200+
nameddimsindices_v = (namedindices_v, domain...)
179201
u = nameddimsarray(u_unnamed, nameddimsindices_u)
180202
s = nameddimsarray(s_unnamed, nameddimsindices_s)
181203
v = nameddimsarray(v_unnamed, nameddimsindices_v)
182204
return u, s, v
183205
end
184-
185-
function TensorAlgebra.svd(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...)
206+
function TensorAlgebra.svd(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
186207
return svd(
187208
a,
188-
nameddimsindices_codomain,
189-
setdiff(nameddimsindices(a), to_nameddimsindices(a, nameddimsindices_codomain));
209+
dimnames_codomain,
210+
setdiff(nameddimsindices(a), to_nameddimsindices(a, dimnames_codomain));
190211
kwargs...,
191212
)
192213
end
193-
194214
function LinearAlgebra.svd(a::AbstractNamedDimsArray, args...; kwargs...)
195215
return TensorAlgebra.svd(a, args...; kwargs...)
196216
end
217+
218+
function TensorAlgebra.svdvals(
219+
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
220+
)
221+
return svdvals(
222+
dename(a),
223+
nameddimsindices(a),
224+
to_nameddimsindices(a, dimnames_codomain),
225+
to_nameddimsindices(a, dimnames_domain);
226+
kwargs...,
227+
)
228+
end
229+
function TensorAlgebra.svdvals(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
230+
codomain = to_nameddimsindices(a, dimnames_codomain)
231+
domain = setdiff(nameddimsindices(a), codomain)
232+
return svdvals(a, codomain, domain; kwargs...)
233+
end
234+
function LinearAlgebra.svdvals(a::AbstractNamedDimsArray, args...; kwargs...)
235+
return TensorAlgebra.svdvals(a, args...; kwargs...)
236+
end
237+
238+
function TensorAlgebra.eigen(
239+
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
240+
)
241+
codomain = to_nameddimsindices(a, dimnames_codomain)
242+
domain = to_nameddimsindices(a, dimnames_domain)
243+
d_unnamed, v_unnamed = eigen(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
244+
name_d = randname(dimnames(a, 1))
245+
name_d′ = randname(name_d)
246+
name_v = name_d
247+
namedindices_d = named(last(axes(d_unnamed)), name_d)
248+
namedindices_d′ = named(first(axes(d_unnamed)), name_d′)
249+
namedindices_v = named(last(axes(v_unnamed)), name_v)
250+
nameddimsindices_d = (namedindices_d′, namedindices_d)
251+
nameddimsindices_v = (domain..., namedindices_v)
252+
d = nameddimsarray(d_unnamed, nameddimsindices_d)
253+
v = nameddimsarray(v_unnamed, nameddimsindices_v)
254+
return d, v
255+
end
256+
function LinearAlgebra.eigen(a::AbstractNamedDimsArray, args...; kwargs...)
257+
return TensorAlgebra.eigen(a, args...; kwargs...)
258+
end
259+
260+
function TensorAlgebra.eigvals(
261+
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
262+
)
263+
codomain = to_nameddimsindices(a, dimnames_codomain)
264+
domain = to_nameddimsindices(a, dimnames_domain)
265+
return eigvals(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
266+
end
267+
function LinearAlgebra.eigvals(a::AbstractNamedDimsArray, args...; kwargs...)
268+
return TensorAlgebra.eigvals(a, args...; kwargs...)
269+
end
270+
271+
function TensorAlgebra.left_null(
272+
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
273+
)
274+
codomain = to_nameddimsindices(a, dimnames_codomain)
275+
domain = to_nameddimsindices(a, dimnames_domain)
276+
n_unnamed = left_null(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
277+
name_n = randname(dimnames(a, 1))
278+
namedindices_n = named(last(axes(n_unnamed)), name_n)
279+
nameddimsindices_n = (codomain..., namedindices_n)
280+
return nameddimsarray(n_unnamed, nameddimsindices_n)
281+
end
282+
function TensorAlgebra.left_null(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
283+
codomain = to_nameddimsindices(a, dimnames_codomain)
284+
domain = setdiff(nameddimsindices(a), codomain)
285+
return left_null(a, codomain, domain; kwargs...)
286+
end
287+
288+
function TensorAlgebra.right_null(
289+
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
290+
)
291+
codomain = to_nameddimsindices(a, dimnames_codomain)
292+
domain = to_nameddimsindices(a, dimnames_domain)
293+
n_unnamed = right_null(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
294+
name_n = randname(dimnames(a, 1))
295+
namedindices_n = named(first(axes(n_unnamed)), name_n)
296+
nameddimsindices_n = (namedindices_n, domain...)
297+
return nameddimsarray(n_unnamed, nameddimsindices_n)
298+
end
299+
function TensorAlgebra.right_null(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
300+
codomain = to_nameddimsindices(a, dimnames_codomain)
301+
domain = setdiff(nameddimsindices(a), codomain)
302+
return right_null(a, codomain, domain; kwargs...)
303+
end

test/test_tensoralgebra.jl

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using LinearAlgebra: qr, svd
2-
using NamedDimsArrays: namedoneto, dename
1+
using LinearAlgebra: lq, norm, qr, svd
2+
using NamedDimsArrays: dename, left_null, nameddimsindices, namedoneto, right_null
33
using TensorAlgebra: TensorAlgebra, contract, fusedims, splitdims
44
using Test: @test, @testset, @test_broken
55
elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@@ -43,20 +43,24 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
4343
@test dename(na_split, ("j", "i", "b"))
4444
reshape(dename(na, ("a", "b")), (dename(j), dename(i), dename(b)))
4545
end
46-
@testset "qr" begin
46+
@testset "qr/lq" begin
4747
dims = (2, 2, 2, 2)
4848
i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l"))
4949

5050
a = randn(elt, i, j)
5151
# TODO: Should this be allowed?
5252
# TODO: Add support for specifying new name.
53-
q, r = qr(a, (i,))
54-
@test q * r a
53+
for f in (qr, lq)
54+
x, y = f(a, (i,))
55+
@test x * y a
56+
end
5557

5658
a = randn(elt, i, j, k, l)
5759
# TODO: Add support for specifying new name.
58-
q, r = qr(a, (i, k), (j, l))
59-
@test q * r a
60+
for f in (qr, lq)
61+
x, y = f(a, (i, k), (j, l))
62+
@test x * y a
63+
end
6064
end
6165
@testset "svd" begin
6266
dims = (2, 2, 2, 2)
@@ -72,5 +76,26 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
7276
# TODO: Add support for specifying new name.
7377
u, s, v = svd(a, (i, k), (j, l))
7478
@test u * s * v a
79+
80+
# Test truncation.
81+
a = randn(elt, i, j, k, l)
82+
u, s, v = svd(a, (i, k), (j, l); trunc=(; maxrank=2))
83+
@test u * s * v a
84+
@test Int.(Tuple(size(s))) == (2, 2)
85+
end
86+
@testset "left_null/eight_null" begin
87+
dims = (2, 2, 2, 2)
88+
i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l"))
89+
90+
a = randn(elt, i, j, k, l)
91+
# TODO: Add support for specifying new name.
92+
for n in (left_null(a, (i, k), (j, l)), left_null(a, (i, k)))
93+
@test (i, k) nameddimsindices(n)
94+
@test norm(n * a) 0
95+
end
96+
for n in (right_null(a, (i, k), (j, l)), right_null(a, (i, k)))
97+
@test (j, l) nameddimsindices(n)
98+
@test norm(n * a) 0
99+
end
75100
end
76101
end

0 commit comments

Comments
 (0)