Skip to content

Commit 920f4c3

Browse files
authored
Define svd (#20)
1 parent eab2fec commit 920f4c3

File tree

3 files changed

+100
-22
lines changed

3 files changed

+100
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.7"
4+
version = "0.3.8"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/tensoralgebra.jl

Lines changed: 77 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
using LinearAlgebra: LinearAlgebra, qr
2-
using TensorAlgebra: TensorAlgebra, blockedperm, contract, contract!, fusedims, splitdims
1+
using LinearAlgebra: LinearAlgebra
2+
using TensorAlgebra:
3+
TensorAlgebra, blockedperm, contract, contract!, fusedims, qr, splitdims, svd
34
using TensorAlgebra.BaseExtensions: BaseExtensions
45

56
function TensorAlgebra.contract!(
@@ -35,6 +36,22 @@ function Base.:*(a1::AbstractNamedDimsArray, a2::AbstractNamedDimsArray)
3536
return contract(a1, a2)
3637
end
3738

39+
# Left associative fold/reduction.
40+
# Circumvent Base definitions:
41+
# ```julia
42+
# *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix)
43+
# *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix)
44+
# ```
45+
# that optimize matrix multiplication sequence.
46+
function Base.:*(
47+
a1::AbstractNamedDimsArray,
48+
a2::AbstractNamedDimsArray,
49+
a3::AbstractNamedDimsArray,
50+
a_rest::AbstractNamedDimsArray...,
51+
)
52+
return *(*(a1, a2), a3, a_rest...)
53+
end
54+
3855
function LinearAlgebra.mul!(
3956
a_dest::AbstractNamedDimsArray,
4057
a1::AbstractNamedDimsArray,
@@ -99,32 +116,78 @@ function TensorAlgebra.splitdims(na::AbstractNamedDimsArray, splitters::Pair...)
99116
return nameddims(a_split, names_split)
100117
end
101118

102-
function LinearAlgebra.qr(
119+
function TensorAlgebra.qr(
103120
a::AbstractNamedDimsArray,
104121
nameddimsindices_codomain,
105122
nameddimsindices_domain;
106123
positive=nothing,
107124
)
108125
@assert isnothing(positive) || !positive
109-
# TODO: This should be `TensorAlgebra.qr` rather than overloading `LinearAlgebra.qr`.
110-
# TODO: Don't require wrapping in `Tuple`.
111-
q, r = qr(
126+
q_unnamed, r_unnamed = qr(
112127
unname(a),
113-
Tuple(nameddimsindices(a)),
114-
Tuple(to_nameddimsindices(a, nameddimsindices_codomain)),
115-
Tuple(to_nameddimsindices(a, nameddimsindices_domain)),
128+
nameddimsindices(a),
129+
to_nameddimsindices(a, nameddimsindices_codomain),
130+
to_nameddimsindices(a, nameddimsindices_domain),
131+
)
132+
name_q = randname(dimnames(a, 1))
133+
name_r = name_q
134+
namedindices_q = named(last(axes(q_unnamed)), name_q)
135+
namedindices_r = named(first(axes(r_unnamed)), name_r)
136+
nameddimsindices_q = (
137+
to_nameddimsindices(a, nameddimsindices_codomain)..., namedindices_q
116138
)
117-
name_qr = randname(nameddimsindices(a)[1])
118-
nameddimsindices_q = (to_nameddimsindices(a, nameddimsindices_codomain)..., name_qr)
119-
nameddimsindices_r = (name_qr, to_nameddimsindices(a, nameddimsindices_domain)...)
120-
return nameddims(q, nameddimsindices_q), nameddims(r, nameddimsindices_r)
139+
nameddimsindices_r = (namedindices_r, to_nameddimsindices(a, nameddimsindices_domain)...)
140+
q = nameddims(q_unnamed, nameddimsindices_q)
141+
r = nameddims(r_unnamed, nameddimsindices_r)
142+
return q, r
121143
end
122144

123-
function LinearAlgebra.qr(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...)
145+
function TensorAlgebra.qr(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...)
124146
return qr(
125147
a,
126148
nameddimsindices_codomain,
127149
setdiff(nameddimsindices(a), to_nameddimsindices(a, nameddimsindices_codomain));
128150
kwargs...,
129151
)
130152
end
153+
154+
function LinearAlgebra.qr(a::AbstractNamedDimsArray, args...; kwargs...)
155+
return TensorAlgebra.qr(a, args...; kwargs...)
156+
end
157+
158+
function TensorAlgebra.svd(
159+
a::AbstractNamedDimsArray, nameddimsindices_codomain, nameddimsindices_domain
160+
)
161+
u_unnamed, s_unnamed, v_unnamed = svd(
162+
unname(a),
163+
nameddimsindices(a),
164+
to_nameddimsindices(a, nameddimsindices_codomain),
165+
to_nameddimsindices(a, nameddimsindices_domain),
166+
)
167+
name_u = randname(dimnames(a, 1))
168+
name_v = randname(dimnames(a, 1))
169+
namedindices_u = named(last(axes(u_unnamed)), name_u)
170+
namedindices_v = named(first(axes(v_unnamed)), name_v)
171+
nameddimsindices_u = (
172+
to_nameddimsindices(a, nameddimsindices_codomain)..., namedindices_u
173+
)
174+
nameddimsindices_s = (namedindices_u, namedindices_v)
175+
nameddimsindices_v = (namedindices_v, to_nameddimsindices(a, nameddimsindices_domain)...)
176+
u = nameddims(u_unnamed, nameddimsindices_u)
177+
s = nameddims(s_unnamed, nameddimsindices_s)
178+
v = nameddims(v_unnamed, nameddimsindices_v)
179+
return u, s, v
180+
end
181+
182+
function TensorAlgebra.svd(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...)
183+
return svd(
184+
a,
185+
nameddimsindices_codomain,
186+
setdiff(nameddimsindices(a), to_nameddimsindices(a, nameddimsindices_codomain));
187+
kwargs...,
188+
)
189+
end
190+
191+
function LinearAlgebra.svd(a::AbstractNamedDimsArray, args...; kwargs...)
192+
return TensorAlgebra.svd(a, args...; kwargs...)
193+
end

test/basics/test_tensoralgebra.jl

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearAlgebra: qr
1+
using LinearAlgebra: qr, svd
22
using NamedDimsArrays: namedoneto, dename
33
using TensorAlgebra: TensorAlgebra, contract, fusedims, splitdims
44
using Test: @test, @testset, @test_broken
@@ -47,15 +47,30 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
4747
dims = (2, 2, 2, 2)
4848
i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l"))
4949

50-
na = randn(elt, i, j)
50+
a = randn(elt, i, j)
5151
# TODO: Should this be allowed?
5252
# TODO: Add support for specifying new name.
53-
q, r = qr(na, (i,))
54-
@test q * r na
53+
q, r = qr(a, (i,))
54+
@test q * r a
5555

56-
na = randn(elt, i, j, k, l)
56+
a = randn(elt, i, j, k, l)
57+
# TODO: Add support for specifying new name.
58+
q, r = qr(a, (i, k), (j, l))
59+
@test q * r a
60+
end
61+
@testset "svd" begin
62+
dims = (2, 2, 2, 2)
63+
i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l"))
64+
65+
a = randn(elt, i, j)
66+
# TODO: Should this be allowed?
67+
# TODO: Add support for specifying new name.
68+
u, s, v = svd(a, (i,))
69+
@test u * s * v a
70+
71+
a = randn(elt, i, j, k, l)
5772
# TODO: Add support for specifying new name.
58-
q, r = qr(na, (i, k), (j, l))
59-
@test contract(q, r) na
73+
u, s, v = svd(a, (i, k), (j, l))
74+
@test u * s * v a
6075
end
6176
end

0 commit comments

Comments
 (0)