Skip to content

Commit 79fc797

Browse files
authored
More factorizations (polar, orth, etc.) (#74)
1 parent d2d7be9 commit 79fc797

File tree

3 files changed

+58
-42
lines changed

3 files changed

+58
-42
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.6.1"
4+
version = "0.6.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -36,6 +36,6 @@ LinearAlgebra = "1.10"
3636
MapBroadcast = "0.1.6"
3737
Random = "1.10"
3838
SimpleTraits = "0.9.4"
39-
TensorAlgebra = "0.2"
39+
TensorAlgebra = "0.2.9"
4040
TypeParameterAccessors = "0.2, 0.3"
4141
julia = "1.10"

src/tensoralgebra.jl

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,17 @@ using TensorAlgebra:
66
contract!,
77
eigen,
88
eigvals,
9+
factorize,
910
fusedims,
1011
left_null,
12+
left_orth,
13+
left_polar,
1114
lq,
1215
permmortar,
1316
qr,
1417
right_null,
18+
right_orth,
19+
right_polar,
1520
splitdims,
1621
svd,
1722
svdvals
@@ -133,55 +138,57 @@ function TensorAlgebra.splitdims(na::AbstractNamedDimsArray, splitters::Pair...)
133138
return nameddimsarray(a_split, names_split)
134139
end
135140

136-
function TensorAlgebra.qr(
137-
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
141+
# Generic interface for forwarding binary factorizations
142+
# to the corresponding functions in TensorAlgebra.jl.
143+
function factorize_with(
144+
f, a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
138145
)
139146
codomain = to_nameddimsindices(a, dimnames_codomain)
140147
domain = to_nameddimsindices(a, dimnames_domain)
141-
q_unnamed, r_unnamed = qr(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
142-
name_q = randname(dimnames(a, 1))
143-
name_r = name_q
144-
namedindices_q = named(last(axes(q_unnamed)), name_q)
145-
namedindices_r = named(first(axes(r_unnamed)), name_r)
146-
nameddimsindices_q = (codomain..., namedindices_q)
147-
nameddimsindices_r = (namedindices_r, domain...)
148-
q = nameddimsarray(q_unnamed, nameddimsindices_q)
149-
r = nameddimsarray(r_unnamed, nameddimsindices_r)
150-
return q, r
148+
x_unnamed, y_unnamed = f(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
149+
name_x = randname(dimnames(a, 1))
150+
name_y = name_x
151+
namedindices_x = named(last(axes(x_unnamed)), name_x)
152+
namedindices_y = named(first(axes(y_unnamed)), name_y)
153+
nameddimsindices_x = (codomain..., namedindices_x)
154+
nameddimsindices_y = (namedindices_y, domain...)
155+
x = nameddimsarray(x_unnamed, nameddimsindices_x)
156+
y = nameddimsarray(y_unnamed, nameddimsindices_y)
157+
return x, y
151158
end
152-
function TensorAlgebra.qr(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
159+
function factorize_with(f, a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
153160
codomain = to_nameddimsindices(a, dimnames_codomain)
154161
domain = setdiff(nameddimsindices(a), codomain)
155-
return qr(a, codomain, domain; kwargs...)
156-
end
157-
function LinearAlgebra.qr(a::AbstractNamedDimsArray, args...; kwargs...)
158-
return TensorAlgebra.qr(a, args...; kwargs...)
162+
return factorize_with(f, a, codomain, domain; kwargs...)
159163
end
160164

161-
function TensorAlgebra.lq(
162-
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
163-
)
164-
codomain = to_nameddimsindices(a, dimnames_codomain)
165-
domain = to_nameddimsindices(a, dimnames_domain)
166-
l_unnamed, q_unnamed = lq(dename(a), nameddimsindices(a), codomain, domain; kwargs...)
167-
name_l = randname(dimnames(a, 1))
168-
name_q = name_l
169-
namedindices_l = named(last(axes(l_unnamed)), name_l)
170-
namedindices_q = named(first(axes(q_unnamed)), name_q)
171-
nameddimsindices_l = (codomain..., namedindices_l)
172-
nameddimsindices_q = (namedindices_q, domain...)
173-
l = nameddimsarray(l_unnamed, nameddimsindices_l)
174-
q = nameddimsarray(q_unnamed, nameddimsindices_q)
175-
return l, q
165+
for f in [:qr, :lq, :left_polar, :right_polar, :left_orth, :right_orth, :factorize]
166+
@eval begin
167+
function TensorAlgebra.$f(
168+
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...
169+
)
170+
return factorize_with($f, a, dimnames_codomain, dimnames_domain; kwargs...)
171+
end
172+
function TensorAlgebra.$f(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
173+
return factorize_with($f, a, dimnames_codomain; kwargs...)
174+
end
175+
end
176176
end
177-
function TensorAlgebra.lq(a::AbstractNamedDimsArray, dimnames_codomain; kwargs...)
178-
codomain = to_nameddimsindices(a, dimnames_codomain)
179-
domain = setdiff(nameddimsindices(a), codomain)
180-
return lq(a, codomain, domain; kwargs...)
177+
178+
# Overload LinearAlgebra functions where relevant.
179+
function LinearAlgebra.qr(a::AbstractNamedDimsArray, args...; kwargs...)
180+
return TensorAlgebra.qr(a, args...; kwargs...)
181181
end
182182
function LinearAlgebra.lq(a::AbstractNamedDimsArray, args...; kwargs...)
183183
return TensorAlgebra.lq(a, args...; kwargs...)
184184
end
185+
function LinearAlgebra.factorize(a::AbstractNamedDimsArray, args...; kwargs...)
186+
return TensorAlgebra.factorize(a, args...; kwargs...)
187+
end
188+
189+
#
190+
# Non-binary factorizations.
191+
#
185192

186193
function TensorAlgebra.svd(
187194
a::AbstractNamedDimsArray, dimnames_codomain, dimnames_domain; kwargs...

test/test_tensoralgebra.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1-
using LinearAlgebra: lq, norm, qr, svd
2-
using NamedDimsArrays: dename, left_null, nameddimsindices, namedoneto, right_null
1+
using LinearAlgebra: factorize, lq, norm, qr, svd
2+
using NamedDimsArrays:
3+
dename,
4+
left_null,
5+
left_orth,
6+
left_polar,
7+
nameddimsindices,
8+
namedoneto,
9+
right_null,
10+
right_orth,
11+
right_polar
312
using TensorAlgebra: TensorAlgebra, contract, fusedims, splitdims
413
using Test: @test, @testset, @test_broken
514
elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@@ -50,14 +59,14 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
5059
a = randn(elt, i, j)
5160
# TODO: Should this be allowed?
5261
# TODO: Add support for specifying new name.
53-
for f in (qr, lq)
62+
for f in (qr, lq, left_polar, right_polar, left_orth, right_orth, factorize)
5463
x, y = f(a, (i,))
5564
@test x * y a
5665
end
5766

5867
a = randn(elt, i, j, k, l)
5968
# TODO: Add support for specifying new name.
60-
for f in (qr, lq)
69+
for f in (qr, lq, left_polar, right_polar, left_orth, right_orth, factorize)
6170
x, y = f(a, (i, k), (j, l))
6271
@test x * y a
6372
end

0 commit comments

Comments
 (0)