Skip to content

Commit a779543

Browse files
authored
Define TensorAlgebra.matricize (#13)
1 parent 9ad40f8 commit a779543

File tree

6 files changed

+126
-121
lines changed

6 files changed

+126
-121
lines changed

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.3"
4+
version = "0.2.4"
55

66
[deps]
77
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
@@ -23,16 +23,16 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
2323
GradedArraysTensorAlgebraExt = "TensorAlgebra"
2424

2525
[compat]
26-
BlockArrays = "1.5.0"
27-
BlockSparseArrays = "0.4.0"
26+
BlockArrays = "1.6.0"
27+
BlockSparseArrays = "0.4.2"
2828
Compat = "4.16.0"
2929
DerivableInterfaces = "0.4.4"
3030
FillArrays = "1.13.0"
3131
HalfIntegers = "1.6.0"
3232
LinearAlgebra = "1.10.0"
3333
Random = "1.10.0"
3434
SplitApplyCombine = "1.2.3"
35-
TensorAlgebra = "0.2.7"
35+
TensorAlgebra = "0.3.2"
3636
TensorProducts = "0.1.3"
3737
TypeParameterAccessors = "0.3.9"
3838
julia = "1.10"
Lines changed: 46 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,109 +1,76 @@
11
module GradedArraysTensorAlgebraExt
22

3-
using BlockArrays: Block, BlockIndexRange, blockedrange, blocks
4-
using BlockSparseArrays:
5-
BlockSparseArrays,
6-
AbstractBlockSparseArray,
7-
AbstractBlockSparseArrayInterface,
8-
BlockSparseArray,
9-
BlockSparseArrayInterface,
10-
BlockSparseMatrix,
11-
BlockSparseVector,
12-
block_merge
13-
using DerivableInterfaces: @interface
3+
using BlockArrays: blocks
4+
using BlockSparseArrays: BlockSparseArray, blockreshape
5+
using GradedArrays: GradedArray
146
using GradedArrays.GradedUnitRanges:
15-
GradedUnitRanges,
167
AbstractGradedUnitRange,
178
blockmergesortperm,
189
blocksortperm,
19-
dual,
10+
flip,
2011
invblockperm,
21-
nondual,
2212
unmerged_tensor_product
23-
using LinearAlgebra: Adjoint, Transpose
13+
using GradedArrays.SymmetrySectors: trivial
2414
using TensorAlgebra:
25-
TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
26-
using TensorProducts: OneToOne
15+
TensorAlgebra,
16+
AbstractBlockPermutation,
17+
BlockedTuple,
18+
FusionStyle,
19+
trivial_axis,
20+
unmatricize
2721

28-
#=
29-
reducewhile(f, op, collection, state)
22+
struct SectorFusion <: FusionStyle end
3023

31-
reducewhile(x -> length(x) < 3, vcat, ["a", "b", "c", "d"], 2; init=String[]) ==
32-
(["b", "c"], 4)
33-
=#
34-
function reducewhile(f, op, collection, state; init)
35-
prev_result = init
36-
prev_state = state
37-
result = prev_result
38-
while f(result)
39-
prev_result = result
40-
prev_state = state
41-
value_and_state = iterate(collection, state)
42-
isnothing(value_and_state) && break
43-
value, state = value_and_state
44-
result = op(result, value)
45-
end
46-
return prev_result, prev_state
47-
end
48-
49-
#=
50-
groupreducewhile(f, op, collection, ngroups)
51-
52-
groupreducewhile((i, x) -> length(x) ≤ i, vcat, ["a", "b", "c", "d", "e", "f"], 3; init=String[]) ==
53-
(["a"], ["b", "c"], ["d", "e", "f"])
54-
=#
55-
function groupreducewhile(f, op, collection, ngroups; init)
56-
state = firstindex(collection)
57-
return ntuple(ngroups) do group_number
58-
result, state = reducewhile(x -> f(group_number, x), op, collection, state; init)
59-
return result
60-
end
61-
end
24+
TensorAlgebra.FusionStyle(::Type{<:GradedArray}) = SectorFusion()
6225

63-
TensorAlgebra.FusionStyle(::AbstractGradedUnitRange) = SectorFusion()
26+
# TBD consider heterogeneous sectors?
27+
TensorAlgebra.trivial_axis(t::Tuple{Vararg{AbstractGradedUnitRange}}) = trivial(first(t))
6428

65-
# Sort the blocks by sector and then merge the common sectors.
66-
function block_mergesort(a::AbstractArray)
67-
I = blockmergesortperm.(axes(a))
68-
return a[I...]
29+
function matricize_axes(
30+
blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}}
31+
)
32+
@assert !isempty(blocked_axes)
33+
default_axis = trivial_axis(Tuple(blocked_axes))
34+
codomain_axes, domain_axes = blocks(blocked_axes)
35+
codomain_axis = unmerged_tensor_product(default_axis, codomain_axes...)
36+
unflipped_domain_axis = unmerged_tensor_product(default_axis, domain_axes...)
37+
return codomain_axis, flip(unflipped_domain_axis)
6938
end
7039

71-
function TensorAlgebra.fusedims(
72-
::SectorFusion, a::AbstractArray, merged_axes::AbstractUnitRange...
40+
function TensorAlgebra.matricize(
41+
::SectorFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
7342
)
74-
# First perform a fusion using a block reshape.
75-
# TODO avoid groupreducewhile. Require refactor of fusedims.
76-
unmerged_axes = groupreducewhile(
77-
unmerged_tensor_product, axes(a), length(merged_axes); init=OneToOne()
78-
) do i, axis
79-
return length(axis) length(merged_axes[i])
80-
end
81-
82-
a_reshaped = fusedims(BlockReshapeFusion(), a, unmerged_axes...)
43+
a_perm = permutedims(a, Tuple(biperm))
44+
codomain_axis, domain_axis = matricize_axes(axes(a)[biperm])
45+
a_reshaped = blockreshape(a_perm, (codomain_axis, domain_axis))
8346
# Sort the blocks by sector and merge the equivalent sectors.
84-
return block_mergesort(a_reshaped)
47+
return sectormergesort(a_reshaped)
8548
end
8649

87-
function TensorAlgebra.splitdims(
88-
::SectorFusion, a::AbstractArray, split_axes::AbstractUnitRange...
50+
function TensorAlgebra.unmatricize(
51+
::SectorFusion,
52+
m::AbstractMatrix,
53+
blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}},
8954
)
9055
# First, fuse axes to get `blockmergesortperm`.
9156
# Then unpermute the blocks.
92-
axes_prod = groupreducewhile(
93-
unmerged_tensor_product, split_axes, ndims(a); init=OneToOne()
94-
) do i, axis
95-
return length(axis) length(axes(a, i))
96-
end
97-
blockperms = blocksortperm.(axes_prod)
98-
sorted_axes = map((r, I) -> only(axes(r[I])), axes_prod, blockperms)
57+
fused_axes = matricize_axes(blocked_axes)
58+
59+
blockperms = blocksortperm.(fused_axes)
60+
sorted_axes = map((r, I) -> only(axes(r[I])), fused_axes, blockperms)
9961

10062
# TODO: This is doing extra copies of the blocks,
10163
# use `@view a[axes_prod...]` instead.
10264
# That will require implementing some reindexing logic
10365
# for this combination of slicing.
104-
a_unblocked = a[sorted_axes...]
105-
a_blockpermed = a_unblocked[invblockperm.(blockperms)...]
106-
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
66+
m_unblocked = m[sorted_axes...]
67+
m_blockpermed = m_unblocked[invblockperm.(blockperms)...]
68+
return unmatricize(FusionStyle(BlockSparseArray), m_blockpermed, blocked_axes)
10769
end
10870

71+
# Sort the blocks by sector and then merge the common sectors.
72+
function sectormergesort(a::AbstractArray)
73+
I = blockmergesortperm.(axes(a))
74+
return a[I...]
75+
end
10976
end

src/gradedarray.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,19 @@ using BlockSparseArrays:
44
AbstractBlockSparseMatrix,
55
AnyAbstractBlockSparseArray,
66
BlockSparseArray,
7+
BlockSparseMatrix,
8+
BlockSparseVector,
79
blocktype
810
using ..GradedUnitRanges: AbstractGradedUnitRange, dual
911
using LinearAlgebra: Adjoint
1012
using TypeParameterAccessors: similartype, unwrap_array_type
1113

14+
const GradedArray{T,M,A,Blocks,Axes} = BlockSparseArray{
15+
T,M,A,Blocks,Axes
16+
} where {Axes<:Tuple{AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}}}
17+
const GradedMatrix{T,A,Blocks,Axes} = GradedArray{T,2,A,Blocks,Axes}
18+
const GradedVector{T,A,Blocks,Axes} = GradedArray{T,1,A,Blocks,Axes}
19+
1220
# TODO: Handle this through some kind of trait dispatch, maybe
1321
# a `SymmetryStyle`-like trait to check if the block sparse
1422
# matrix has graded axes.

test/Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1515

1616
[compat]
1717
Aqua = "0.8.11"
18-
BlockArrays = "1.5.0"
19-
BlockSparseArrays = "0.4.0"
18+
BlockArrays = "1.6.0"
19+
BlockSparseArrays = "0.4.2"
2020
GradedArrays = "0.2.0"
2121
LinearAlgebra = "1.10.0"
2222
Random = "1.10.0"
2323
SafeTestsets = "0.1.0"
2424
SparseArraysBase = "0.5.4"
2525
Suppressor = "0.2.8"
26-
TensorAlgebra = "0.2.7"
26+
TensorAlgebra = "0.3.2"
2727
TensorProducts = "0.1.3"
2828
Test = "1.10.0"
2929
TestExtras = "0.3.1"

test/test_gradedarray.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ using BlockArrays:
22
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
33
using BlockSparseArrays:
44
BlockSparseArray, BlockSparseMatrix, BlockSparseVector, blockstoredlength
5-
using GradedArrays:
5+
using GradedArrays: GradedArray, GradedMatrix, GradedVector
6+
using GradedArrays.GradedUnitRanges:
67
GradedUnitRanges,
78
GradedOneTo,
89
GradedUnitRange,
910
GradedUnitRangeDual,
10-
blocklabels,
1111
dag,
1212
dual,
1313
gradedrange,
@@ -31,6 +31,19 @@ end
3131

3232
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
3333
@testset "GradedArray (eltype=$elt)" for elt in elts
34+
@testset "definitions" begin
35+
r = gradedrange([U1(0) => 2, U1(1) => 2])
36+
a = BlockSparseArray{elt}(undef)
37+
@test !(a isa GradedArray) # no type piracy
38+
v = BlockSparseArray{elt}(undef, r)
39+
@test v isa GradedArray
40+
@test v isa GradedVector
41+
m = BlockSparseArray{elt}(undef, r, r)
42+
@test m isa GradedArray
43+
@test m isa GradedMatrix
44+
a = BlockSparseArray{elt}(undef, r, r, r)
45+
@test a isa GradedArray
46+
end
3447
@testset "map" begin
3548
d1 = gradedrange([U1(0) => 2, U1(1) => 2])
3649
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
@@ -60,6 +73,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
6073
r = gradedrange([U1(0) => 2, U1(1) => 2])
6174
a = zeros(r, r, r, r)
6275
@test a isa BlockSparseArray{Float64}
76+
@test a isa GradedArray
6377
@test eltype(a) === Float64
6478
@test size(a) == (4, 4, 4, 4)
6579
@test iszero(a)

test/test_tensoralgebraext.jl

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
using BlockArrays: Block, blocksize
2-
using BlockSparseArrays: BlockSparseArray
3-
using GradedArrays: GradedOneTo, blocklabels, dual, gradedrange
2+
using BlockSparseArrays: BlockSparseArray, BlockSparseMatrix
3+
using GradedArrays: GradedOneTo, blocklabels, dual, flip, gradedrange, space_isequal
44
using GradedArrays.SymmetrySectors: U1
55
using Random: randn!
6-
using TensorAlgebra: contract, fusedims, splitdims
6+
using TensorAlgebra: contract, matricize, unmatricize
77
using Test: @test, @test_broken, @testset
88

99
function randn_blockdiagonal(elt::Type, axes::Tuple)
@@ -18,7 +18,51 @@ end
1818

1919
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
2020
@testset "`contract` `GradedArray` (eltype=$elt)" for elt in elts
21-
@testset "GradedOneTo with U(1)" begin
21+
@testset "matricize" begin
22+
d1 = gradedrange([U1(0) => 1, U1(1) => 1])
23+
d2 = gradedrange([U1(0) => 1, U1(1) => 1])
24+
a = randn_blockdiagonal(elt, (d1, d2, dual(d1), dual(d2)))
25+
m = matricize(a, (1, 2), (3, 4))
26+
@test m isa BlockSparseMatrix
27+
@test space_isequal(axes(m, 1), gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 1]))
28+
@test space_isequal(
29+
axes(m, 2), flip(gradedrange([U1(0) => 1, U1(-1) => 2, U1(-2) => 1]))
30+
)
31+
32+
for I in CartesianIndices(m)
33+
if I CartesianIndex.([(1, 1), (4, 4)])
34+
@test !iszero(m[I])
35+
else
36+
@test iszero(m[I])
37+
end
38+
end
39+
@test a[1, 1, 1, 1] == m[1, 1]
40+
@test a[2, 2, 2, 2] == m[4, 4]
41+
@test blocksize(m) == (3, 3)
42+
@test a == unmatricize(m, (d1, d2), (dual(d1), dual(d2)))
43+
44+
# check block fusing and splitting
45+
d = gradedrange([U1(0) => 2, U1(1) => 1])
46+
b = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
47+
@test unmatricize(
48+
matricize(b, (1, 2), (3, 4)), (axes(b, 1), axes(b, 2)), (axes(b, 3), axes(b, 4))
49+
) == b
50+
51+
d1234 = gradedrange([U1(-2) => 1, U1(-1) => 4, U1(0) => 6, U1(1) => 4, U1(2) => 1])
52+
m = matricize(a, (1, 2, 3, 4), ())
53+
@test m isa BlockSparseMatrix
54+
@test space_isequal(axes(m, 1), d1234)
55+
@test space_isequal(axes(m, 2), flip(gradedrange([U1(0) => 1])))
56+
@test a == unmatricize(m, (d1, d2, dual(d1), dual(d2)), ())
57+
58+
m = matricize(a, (), (1, 2, 3, 4))
59+
@test m isa BlockSparseMatrix
60+
@test space_isequal(axes(m, 1), gradedrange([U1(0) => 1]))
61+
@test space_isequal(axes(m, 2), dual(d1234))
62+
@test a == unmatricize(m, (), (d1, d2, dual(d1), dual(d2)))
63+
end
64+
65+
@testset "contract with U(1)" begin
2266
d = gradedrange([U1(0) => 2, U1(1) => 3])
2367
a1 = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
2468
a2 = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
@@ -38,14 +82,12 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
3882
@test a_dest a_dest_dense
3983

4084
# matrix vector
41-
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
42-
#=
85+
a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
4386
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
4487
@test dimnames_dest == dimnames_dest_dense
4588
@test size(a_dest) == size(a_dest_dense)
4689
@test a_dest isa BlockSparseArray
4790
@test a_dest a_dest_dense
48-
=#
4991

5092
# vector matrix
5193
a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
@@ -71,30 +113,4 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
71113
@test a_dest isa BlockSparseArray
72114
@test a_dest a_dest_dense
73115
end
74-
@testset "fusedims" begin
75-
d1 = gradedrange([U1(0) => 1, U1(1) => 1])
76-
d2 = gradedrange([U1(0) => 1, U1(1) => 1])
77-
a = randn_blockdiagonal(elt, (d1, d2, d1, d2))
78-
m = fusedims(a, (1, 2), (3, 4))
79-
for ax in axes(m)
80-
@test ax isa GradedOneTo
81-
@test blocklabels(ax) == [U1(0), U1(1), U1(2)]
82-
end
83-
for I in CartesianIndices(m)
84-
if I CartesianIndex.([(1, 1), (4, 4)])
85-
@test !iszero(m[I])
86-
else
87-
@test iszero(m[I])
88-
end
89-
end
90-
@test a[1, 1, 1, 1] == m[1, 1]
91-
@test a[2, 2, 2, 2] == m[4, 4]
92-
@test blocksize(m) == (3, 3)
93-
@test a == splitdims(m, (d1, d2), (d1, d2))
94-
95-
# check block fusing and splitting
96-
d = gradedrange([U1(0) => 2, U1(1) => 1])
97-
a = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
98-
@test splitdims(fusedims(a, (1, 2), (3, 4)), axes(a)...) == a
99-
end
100116
end

0 commit comments

Comments
 (0)