Skip to content

Commit 4f322a8

Browse files
mtfishmanlkdvos
andauthored
Abelian symmetric SVD (#33)
Co-authored-by: Lukas Devos <[email protected]>
1 parent 4779a52 commit 4f322a8

File tree

7 files changed

+216
-13
lines changed

7 files changed

+216
-13
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.2"
4+
version = "0.4.4"
55

66
[deps]
77
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
@@ -11,6 +11,7 @@ DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
1111
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1212
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14+
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1415
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1516
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
1617
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
@@ -24,12 +25,13 @@ GradedArraysTensorAlgebraExt = "TensorAlgebra"
2425

2526
[compat]
2627
BlockArrays = "1.6.0"
27-
BlockSparseArrays = "0.5"
28+
BlockSparseArrays = "0.6.1"
2829
Compat = "4.16.0"
2930
DerivableInterfaces = "0.4.4"
3031
FillArrays = "1.13.0"
3132
HalfIntegers = "1.6.0"
3233
LinearAlgebra = "1.10.0"
34+
MatrixAlgebraKit = "0.2"
3335
Random = "1.10.0"
3436
SplitApplyCombine = "1.2.3"
3537
TensorAlgebra = "0.3.2"

src/GradedArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ include("sector_product.jl")
2020

2121
include("fusion.jl")
2222
include("gradedarray.jl")
23+
include("factorizations.jl")
2324

2425
export SU2,
2526
U1,

src/factorizations.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
using BlockArrays: blocks
2+
using BlockSparseArrays:
3+
BlockSparseArrays,
4+
BlockSparseMatrix,
5+
BlockPermutedDiagonalAlgorithm,
6+
BlockPermutedDiagonalTruncationStrategy,
7+
diagview,
8+
eachblockaxis,
9+
mortar_axis
10+
using LinearAlgebra: Diagonal
11+
using MatrixAlgebraKit: MatrixAlgebraKit, svd_compact!, svd_full!, svd_trunc!
12+
13+
function BlockSparseArrays.similar_output(
14+
::typeof(svd_compact!), A::GradedMatrix, S_axes, alg::BlockPermutedDiagonalAlgorithm
15+
)
16+
u_axis, v_axis = S_axes
17+
U = similar(A, axes(A, 1), dual(u_axis))
18+
T = real(eltype(A))
19+
S = BlockSparseMatrix{T,Diagonal{T,Vector{T}}}(undef, (u_axis, v_axis))
20+
Vt = similar(A, dual(v_axis), axes(A, 2))
21+
return U, S, Vt
22+
end
23+
24+
function BlockSparseArrays.similar_output(
25+
::typeof(svd_full!), A::GradedMatrix, S_axes, alg::BlockPermutedDiagonalAlgorithm
26+
)
27+
u_axis, s_axis = S_axes
28+
U = similar(A, axes(A, 1), dual(u_axis))
29+
T = real(eltype(A))
30+
S = similar(A, T, S_axes)
31+
Vt = similar(A, dual(S_axes[2]), axes(A, 2))
32+
return U, S, Vt
33+
end
34+
35+
const TGradedUSVᴴ = Tuple{<:GradedMatrix,<:GradedMatrix,<:GradedMatrix}
36+
37+
function BlockSparseArrays.similar_truncate(
38+
::typeof(svd_trunc!),
39+
(U, S, Vᴴ)::TGradedUSVᴴ,
40+
strategy::BlockPermutedDiagonalTruncationStrategy,
41+
indexmask=MatrixAlgebraKit.findtruncated(diagview(S), strategy),
42+
)
43+
u_axis, v_axis = axes(S)
44+
counter = Base.Fix1(count, Base.Fix1(getindex, indexmask))
45+
s_lengths = map(counter, blocks(u_axis))
46+
u_sectors = sectors(u_axis) .=> s_lengths
47+
v_sectors = sectors(v_axis) .=> s_lengths
48+
u_sectors_filtered = filter(>(0) last, u_sectors)
49+
v_sectors_filtered = filter(>(0) last, v_sectors)
50+
u_axis′ = gradedrange(u_sectors_filtered)
51+
u_axis = isdual(u_axis) ? dual(u_axis′) : u_axis′
52+
v_axis′ = gradedrange(v_sectors_filtered)
53+
v_axis = isdual(v_axis) ? dual(v_axis′) : v_axis′
54+
= similar(U, axes(U, 1), dual(u_axis))
55+
= similar(S, u_axis, v_axis)
56+
Ṽᴴ = similar(Vᴴ, dual(v_axis), axes(Vᴴ, 2))
57+
return Ũ, S̃, Ṽᴴ
58+
end

src/gradedarray.jl

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using BlockSparseArrays:
55
AnyAbstractBlockSparseArray,
66
BlockSparseArray,
77
blocktype,
8+
eachblockstoredindex,
89
sparsemortar
910
using LinearAlgebra: Adjoint
1011
using TypeParameterAccessors: similartype, unwrap_array_type
@@ -41,13 +42,19 @@ function similar_blocksparse(
4142
elt::Type,
4243
axes::Tuple{AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}},
4344
)
44-
# TODO: Probably need to unwrap the type of `a` in certain cases
45-
# to make a proper block type.
46-
return BlockSparseArray{
47-
elt,length(axes),similartype(unwrap_array_type(blocktype(a)), elt, axes)
48-
}(
49-
undef, axes
45+
blockaxistypes = map(axes) do axis
46+
return eltype(Base.promote_op(eachblockaxis, typeof(axis)))
47+
end
48+
similar_blocktype = Base.promote_op(
49+
similar, blocktype(a), Type{elt}, Tuple{blockaxistypes...}
5050
)
51+
return BlockSparseArray{elt,length(axes),similar_blocktype}(undef, axes)
52+
end
53+
54+
function Base.similar(
55+
a::AbstractArray, elt::Type, axes::Tuple{SectorOneTo,Vararg{SectorOneTo}}
56+
)
57+
return similar(a, elt, Base.OneTo.(length.(axes)))
5158
end
5259

5360
function Base.similar(
@@ -120,6 +127,25 @@ end
120127

121128
ungrade(a::GradedArray) = sparsemortar(blocks(a), ungrade.(axes(a)))
122129

130+
function flux(a::GradedArray{<:Any,N}, I::Vararg{Block{1},N}) where {N}
131+
sects = ntuple(N) do d
132+
return flux(axes(a, d), I[d])
133+
end
134+
return (sects...)
135+
end
136+
function flux(a::GradedArray{<:Any,N}, I::Block{N}) where {N}
137+
return flux(a, Tuple(I)...)
138+
end
139+
function flux(a::GradedArray)
140+
sect = nothing
141+
for I in eachblockstoredindex(a)
142+
sect_I = flux(a, I)
143+
isnothing(sect) || sect_I == sect || throw(ArgumentError("Inconsistent flux."))
144+
sect = sect_I
145+
end
146+
return sect
147+
end
148+
123149
# Copy of `Base.dims2string` defined in `show.jl`.
124150
function dims_to_string(d)
125151
isempty(d) && return "0-dimensional"

src/gradedunitrange.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ using BlockArrays:
1818
combine_blockaxes,
1919
findblock,
2020
mortar
21-
using BlockSparseArrays: BlockSparseArrays, blockedunitrange_getindices
21+
using BlockSparseArrays:
22+
BlockSparseArrays, blockedunitrange_getindices, eachblockaxis, mortar_axis
2223
using Compat: allequal
2324

2425
# ==================================== Definitions =======================================
@@ -60,7 +61,7 @@ end
6061

6162
# ===================================== Accessors ========================================
6263

63-
eachblockaxis(g::GradedUnitRange) = g.eachblockaxis
64+
BlockSparseArrays.eachblockaxis(g::GradedUnitRange) = g.eachblockaxis
6465
ungrade(g::GradedUnitRange) = g.full_range
6566

6667
sector_multiplicities(g::GradedUnitRange) = sector_multiplicity.(eachblockaxis(g))
@@ -69,12 +70,12 @@ sector_type(::Type{<:GradedUnitRange{<:Any,SUR}}) where {SUR} = sector_type(SUR)
6970

7071
# ==================================== Constructors ======================================
7172

72-
function mortar_axis(geachblockaxis::AbstractVector{<:SectorOneTo})
73+
function BlockSparseArrays.mortar_axis(geachblockaxis::AbstractVector{<:SectorOneTo})
7374
brange = blockedrange(length.(geachblockaxis))
7475
return GradedUnitRange(geachblockaxis, brange)
7576
end
7677

77-
function mortar_axis(gaxes::AbstractVector{<:GradedOneTo})
78+
function BlockSparseArrays.mortar_axis(gaxes::AbstractVector{<:GradedOneTo})
7879
return mortar_axis(mapreduce(eachblockaxis, vcat, gaxes))
7980
end
8081

@@ -102,6 +103,11 @@ function sectors(g::AbstractGradedUnitRange)
102103
return sector.(eachblockaxis(g))
103104
end
104105

106+
function flux(a::AbstractGradedUnitRange, I::Block{1})
107+
sect = sector(a[I])
108+
return isdual(a) ? dual(sect) : sect
109+
end
110+
105111
function map_sectors(f, g::GradedUnitRange)
106112
return GradedUnitRange(map_sectors.(f, eachblockaxis(g)), ungrade(g))
107113
end

test/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
44
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
55
GradedArrays = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7+
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
78
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
89
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
910
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
@@ -16,9 +17,10 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1617
[compat]
1718
Aqua = "0.8.11"
1819
BlockArrays = "1.6.0"
19-
BlockSparseArrays = "0.5"
20+
BlockSparseArrays = "0.6"
2021
GradedArrays = "0.4"
2122
LinearAlgebra = "1.10.0"
23+
MatrixAlgebraKit = "0.2"
2224
Random = "1.10.0"
2325
SafeTestsets = "0.1.0"
2426
SparseArraysBase = "0.5.4"

test/test_factorizations.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
using BlockArrays: Block, blocksizes
2+
using GradedArrays: U1, dual, flux, gradedrange
3+
using LinearAlgebra: I, diag, svdvals
4+
using MatrixAlgebraKit: svd_compact, svd_full, svd_trunc
5+
using Test: @test, @testset
6+
7+
const elts = (Float32, Float64, ComplexF32, ComplexF64)
8+
@testset "svd_compact (eltype=$elt)" for elt in elts
9+
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
10+
r1 = gradedrange([U1(0) => i, U1(1) => j])
11+
r2 = gradedrange([U1(0) => k, U1(1) => l])
12+
a = zeros(elt, r1, dual(r2))
13+
a[Block(2, 2)] = randn(elt, blocksizes(a)[2, 2])
14+
@test flux(a) == U1(0)
15+
u, s, vᴴ = svd_compact(a)
16+
@test sort(diag(Matrix(s)); rev=true) svdvals(Matrix(a))[1:size(s, 1)]
17+
@test u * s * vᴴ a
18+
@test Array(u'u) I
19+
@test Array(vᴴ * vᴴ') I
20+
@test flux(u) == U1(0)
21+
@test flux(s) == flux(a)
22+
@test flux(vᴴ) == U1(0)
23+
24+
r1 = gradedrange([U1(0) => i, U1(1) => j])
25+
r2 = gradedrange([U1(0) => k, U1(1) => l])
26+
a = zeros(elt, r1, dual(r2))
27+
a[Block(1, 2)] = randn(elt, blocksizes(a)[1, 2])
28+
@test flux(a) == U1(-1)
29+
u, s, vᴴ = svd_compact(a)
30+
@test sort(diag(Matrix(s)); rev=true) svdvals(Matrix(a))[1:size(s, 1)]
31+
@test u * s * vᴴ a
32+
@test Array(u'u) I
33+
@test Array(vᴴ * vᴴ') I
34+
@test flux(u) == U1(0)
35+
@test flux(s) == flux(a)
36+
@test flux(vᴴ) == U1(0)
37+
end
38+
end
39+
40+
@testset "svd_full (eltype=$elt)" for elt in elts
41+
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
42+
r1 = gradedrange([U1(0) => i, U1(1) => j])
43+
r2 = gradedrange([U1(0) => k, U1(1) => l])
44+
a = zeros(elt, r1, dual(r2))
45+
a[Block(2, 2)] = randn(elt, blocksizes(a)[2, 2])
46+
@test flux(a) == U1(0)
47+
u, s, vᴴ = svd_full(a)
48+
@test u * s * vᴴ a
49+
@test Array(u'u) I
50+
@test Array(u * u') I
51+
@test Array(vᴴ * vᴴ') I
52+
@test Array(vᴴ'vᴴ) I
53+
@test flux(u) == U1(0)
54+
@test flux(s) == flux(a)
55+
@test flux(vᴴ) == U1(0)
56+
57+
r1 = gradedrange([U1(0) => i, U1(1) => j])
58+
r2 = gradedrange([U1(0) => k, U1(1) => l])
59+
a = zeros(elt, r1, dual(r2))
60+
a[Block(1, 2)] = randn(elt, blocksizes(a)[1, 2])
61+
@test flux(a) == U1(-1)
62+
u, s, vᴴ = svd_full(a)
63+
@test u * s * vᴴ a
64+
@test Array(u'u) I
65+
@test Array(u * u') I
66+
@test Array(vᴴ * vᴴ') I
67+
@test Array(vᴴ'vᴴ) I
68+
@test flux(u) == U1(0)
69+
@test flux(s) == flux(a)
70+
@test flux(vᴴ) == U1(0)
71+
end
72+
end
73+
74+
@testset "svd_trunc (eltype=$elt)" for elt in elts
75+
for i in [2, 3], j in [2, 3], k in [2, 3], l in [2, 3]
76+
r1 = gradedrange([U1(0) => i, U1(1) => j])
77+
r2 = gradedrange([U1(0) => k, U1(1) => l])
78+
a = zeros(elt, r1, dual(r2))
79+
a[Block(2, 2)] = randn(elt, blocksizes(a)[2, 2])
80+
@test flux(a) == U1(0)
81+
u, s, vᴴ = svd_trunc(a; trunc=(; maxrank=1))
82+
@test sort(diag(Matrix(s)); rev=true) svdvals(Matrix(a))[1:size(s, 1)]
83+
@test size(u) == (size(a, 1), 1)
84+
@test size(s) == (1, 1)
85+
@test size(vᴴ) == (1, size(a, 2))
86+
@test Array(u'u) I
87+
@test Array(vᴴ * vᴴ') I
88+
@test flux(u) == U1(0)
89+
@test flux(s) == flux(a)
90+
@test flux(vᴴ) == U1(0)
91+
92+
r1 = gradedrange([U1(0) => i, U1(1) => j])
93+
r2 = gradedrange([U1(0) => k, U1(1) => l])
94+
a = zeros(elt, r1, dual(r2))
95+
a[Block(1, 2)] = randn(elt, blocksizes(a)[1, 2])
96+
@test flux(a) == U1(-1)
97+
u, s, vᴴ = svd_trunc(a; trunc=(; maxrank=1))
98+
@test sort(diag(Matrix(s)); rev=true) svdvals(Matrix(a))[1:size(s, 1)]
99+
@test size(u) == (size(a, 1), 1)
100+
@test size(s) == (1, 1)
101+
@test size(vᴴ) == (1, size(a, 2))
102+
@test Array(u'u) I
103+
@test Array(vᴴ * vᴴ') I
104+
@test flux(u) == U1(0)
105+
@test flux(s) == flux(a)
106+
@test flux(vᴴ) == U1(0)
107+
end
108+
end

0 commit comments

Comments
 (0)