Skip to content

Commit 0ff0ed3

Browse files
committed
Define TensorAlgebra.svd, change qr namespace
1 parent 4453f64 commit 0ff0ed3

File tree

6 files changed

+62
-77
lines changed

6 files changed

+62
-77
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.6"
4+
version = "0.1.7"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/LinearAlgebraExtensions/qr.jl

Lines changed: 0 additions & 69 deletions
This file was deleted.

src/TensorAlgebra.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ include("contract/output_labels.jl")
1212
include("contract/blockedperms.jl")
1313
include("contract/allocate_output.jl")
1414
include("contract/contract_matricize/contract.jl")
15-
# TODO: Rename to `TensorAlgebraLinearAlgebraExt`.
16-
include("LinearAlgebraExtensions/LinearAlgebraExtensions.jl")
15+
include("factorizations.jl")
1716

1817
end

src/factorizations.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using ArrayLayouts: LayoutMatrix
2+
using LinearAlgebra: LinearAlgebra, Diagonal
3+
4+
function qr(a::AbstractArray, biperm::BlockedPermutation{2})
5+
a_matricized = fusedims(a, biperm)
6+
# TODO: Make this more generic, allow choosing thin or full,
7+
# make sure this works on GPU.
8+
q_fact, r_matricized = LinearAlgebra.qr(a_matricized)
9+
q_matricized = typeof(a_matricized)(q_fact)
10+
axes_codomain, axes_domain = blockpermute(axes(a), biperm)
11+
axes_q = (axes_codomain..., axes(q_matricized, 2))
12+
axes_r = (axes(r_matricized, 1), axes_domain...)
13+
q = splitdims(q_matricized, axes_q)
14+
r = splitdims(r_matricized, axes_r)
15+
return q, r
16+
end
17+
18+
function qr(
19+
a::AbstractArray, labels_a, labels_codomain, labels_domain
20+
)
21+
# TODO: Generalize to conversion to `Tuple` isn't needed.
22+
return qr(a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain)))
23+
end
24+
25+
function svd(a::AbstractArray, biperm::BlockedPermutation{2})
26+
a_matricized = fusedims(a, biperm)
27+
usv_matricized = LinearAlgebra.svd(a_matricized)
28+
u_matricized = usv_matricized.U
29+
s_diag = usv_matricized.S
30+
v_matricized = usv_matricized.Vt
31+
axes_codomain, axes_domain = blockpermute(axes(a), biperm)
32+
axes_u = (axes_codomain..., axes(u_matricized, 2))
33+
axes_v = (axes(v_matricized, 1), axes_domain...)
34+
u = splitdims(u_matricized, axes_u)
35+
# TODO: Use `DiagonalArrays.diagonal` to make it more general.
36+
s = Diagonal(s_diag)
37+
v = splitdims(v_matricized, axes_v)
38+
return u, s, v
39+
end
40+
41+
function svd(
42+
a::AbstractArray, labels_a, labels_codomain, labels_domain
43+
)
44+
return svd(a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain)))
45+
end

test/test_basics.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using EllipsisNotation: var".."
2-
using LinearAlgebra: norm, qr
2+
using LinearAlgebra: norm
33
using StableRNGs: StableRNG
4-
using TensorAlgebra: contract, contract!, fusedims, splitdims
4+
using TensorAlgebra: contract, contract!, fusedims, qr, splitdims, svd
55
using TensorOperations: TensorOperations
66
using Test: @test, @test_broken, @testset
77

@@ -222,3 +222,16 @@ end
222222
a′ = contract(labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...))
223223
@test a a′
224224
end
225+
@testset "svd (eltype=$elt)" for elt in elts
226+
a = randn(elt, 5, 4, 3, 2)
227+
labels_a = (:a, :b, :c, :d)
228+
labels_u = (:b, :a)
229+
labels_v = (:d, :c)
230+
u, s, v = svd(a, labels_a, labels_u, labels_v)
231+
label_u = :u
232+
label_v = :v
233+
# TODO: Define multi-arg `contract`?
234+
us, labels_us = contract(u, (labels_u..., label_u), s, (label_u, label_v))
235+
a′ = contract(labels_a, us, labels_us, v, (label_v, labels_v...))
236+
@test a a′
237+
end

0 commit comments

Comments
 (0)