Skip to content

Define TensorAlgebra.matricize #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GradedArrays"
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.3"
version = "0.2.4"

[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Expand All @@ -23,16 +23,16 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
GradedArraysTensorAlgebraExt = "TensorAlgebra"

[compat]
BlockArrays = "1.5.0"
BlockSparseArrays = "0.4.0"
BlockArrays = "1.6.0"
BlockSparseArrays = "0.4.2"
Compat = "4.16.0"
DerivableInterfaces = "0.4.4"
FillArrays = "1.13.0"
HalfIntegers = "1.6.0"
LinearAlgebra = "1.10.0"
Random = "1.10.0"
SplitApplyCombine = "1.2.3"
TensorAlgebra = "0.2.7"
TensorAlgebra = "0.3.2"
TensorProducts = "0.1.3"
TypeParameterAccessors = "0.3.9"
julia = "1.10"
125 changes: 46 additions & 79 deletions ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl
Original file line number Diff line number Diff line change
@@ -1,109 +1,76 @@
module GradedArraysTensorAlgebraExt

using BlockArrays: Block, BlockIndexRange, blockedrange, blocks
using BlockSparseArrays:
BlockSparseArrays,
AbstractBlockSparseArray,
AbstractBlockSparseArrayInterface,
BlockSparseArray,
BlockSparseArrayInterface,
BlockSparseMatrix,
BlockSparseVector,
block_merge
using DerivableInterfaces: @interface
using BlockArrays: blocks
using BlockSparseArrays: BlockSparseArray, blockreshape
using GradedArrays: GradedArray
using GradedArrays.GradedUnitRanges:
GradedUnitRanges,
AbstractGradedUnitRange,
blockmergesortperm,
blocksortperm,
dual,
flip,
invblockperm,
nondual,
unmerged_tensor_product
using LinearAlgebra: Adjoint, Transpose
using GradedArrays.SymmetrySectors: trivial
using TensorAlgebra:
TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
using TensorProducts: OneToOne
TensorAlgebra,
AbstractBlockPermutation,
BlockedTuple,
FusionStyle,
trivial_axis,
unmatricize

#=
reducewhile(f, op, collection, state)
struct SectorFusion <: FusionStyle end

reducewhile(x -> length(x) < 3, vcat, ["a", "b", "c", "d"], 2; init=String[]) ==
(["b", "c"], 4)
=#
function reducewhile(f, op, collection, state; init)
prev_result = init
prev_state = state
result = prev_result
while f(result)
prev_result = result
prev_state = state
value_and_state = iterate(collection, state)
isnothing(value_and_state) && break
value, state = value_and_state
result = op(result, value)
end
return prev_result, prev_state
end

#=
groupreducewhile(f, op, collection, ngroups)

groupreducewhile((i, x) -> length(x) ≤ i, vcat, ["a", "b", "c", "d", "e", "f"], 3; init=String[]) ==
(["a"], ["b", "c"], ["d", "e", "f"])
=#
function groupreducewhile(f, op, collection, ngroups; init)
state = firstindex(collection)
return ntuple(ngroups) do group_number
result, state = reducewhile(x -> f(group_number, x), op, collection, state; init)
return result
end
end
TensorAlgebra.FusionStyle(::Type{<:GradedArray}) = SectorFusion()

Check warning on line 24 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L24

Added line #L24 was not covered by tests

TensorAlgebra.FusionStyle(::AbstractGradedUnitRange) = SectorFusion()
# TBD consider heterogeneous sectors?
TensorAlgebra.trivial_axis(t::Tuple{Vararg{AbstractGradedUnitRange}}) = trivial(first(t))

Check warning on line 27 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L27

Added line #L27 was not covered by tests

# Sort the blocks by sector and then merge the common sectors.
function block_mergesort(a::AbstractArray)
I = blockmergesortperm.(axes(a))
return a[I...]
function matricize_axes(

Check warning on line 29 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L29

Added line #L29 was not covered by tests
blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}}
)
@assert !isempty(blocked_axes)
default_axis = trivial_axis(Tuple(blocked_axes))
codomain_axes, domain_axes = blocks(blocked_axes)
codomain_axis = unmerged_tensor_product(default_axis, codomain_axes...)
unflipped_domain_axis = unmerged_tensor_product(default_axis, domain_axes...)
return codomain_axis, flip(unflipped_domain_axis)

Check warning on line 37 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L32-L37

Added lines #L32 - L37 were not covered by tests
end

function TensorAlgebra.fusedims(
::SectorFusion, a::AbstractArray, merged_axes::AbstractUnitRange...
function TensorAlgebra.matricize(

Check warning on line 40 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L40

Added line #L40 was not covered by tests
::SectorFusion, a::AbstractArray, biperm::AbstractBlockPermutation{2}
)
# First perform a fusion using a block reshape.
# TODO avoid groupreducewhile. Require refactor of fusedims.
unmerged_axes = groupreducewhile(
unmerged_tensor_product, axes(a), length(merged_axes); init=OneToOne()
) do i, axis
return length(axis) ≤ length(merged_axes[i])
end

a_reshaped = fusedims(BlockReshapeFusion(), a, unmerged_axes...)
a_perm = permutedims(a, Tuple(biperm))
codomain_axis, domain_axis = matricize_axes(axes(a)[biperm])
a_reshaped = blockreshape(a_perm, (codomain_axis, domain_axis))

Check warning on line 45 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L43-L45

Added lines #L43 - L45 were not covered by tests
# Sort the blocks by sector and merge the equivalent sectors.
return block_mergesort(a_reshaped)
return sectormergesort(a_reshaped)

Check warning on line 47 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L47

Added line #L47 was not covered by tests
end

function TensorAlgebra.splitdims(
::SectorFusion, a::AbstractArray, split_axes::AbstractUnitRange...
function TensorAlgebra.unmatricize(

Check warning on line 50 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L50

Added line #L50 was not covered by tests
::SectorFusion,
m::AbstractMatrix,
blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}},
)
# First, fuse axes to get `blockmergesortperm`.
# Then unpermute the blocks.
axes_prod = groupreducewhile(
unmerged_tensor_product, split_axes, ndims(a); init=OneToOne()
) do i, axis
return length(axis) ≤ length(axes(a, i))
end
blockperms = blocksortperm.(axes_prod)
sorted_axes = map((r, I) -> only(axes(r[I])), axes_prod, blockperms)
fused_axes = matricize_axes(blocked_axes)

Check warning on line 57 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L57

Added line #L57 was not covered by tests

blockperms = blocksortperm.(fused_axes)
sorted_axes = map((r, I) -> only(axes(r[I])), fused_axes, blockperms)

Check warning on line 60 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L59-L60

Added lines #L59 - L60 were not covered by tests

# TODO: This is doing extra copies of the blocks,
# use `@view a[axes_prod...]` instead.
# That will require implementing some reindexing logic
# for this combination of slicing.
a_unblocked = a[sorted_axes...]
a_blockpermed = a_unblocked[invblockperm.(blockperms)...]
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
m_unblocked = m[sorted_axes...]
m_blockpermed = m_unblocked[invblockperm.(blockperms)...]
return unmatricize(FusionStyle(BlockSparseArray), m_blockpermed, blocked_axes)

Check warning on line 68 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L66-L68

Added lines #L66 - L68 were not covered by tests
end

# Sort the blocks by sector and then merge the common sectors.
function sectormergesort(a::AbstractArray)
I = blockmergesortperm.(axes(a))
return a[I...]

Check warning on line 74 in ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/GradedArraysTensorAlgebraExt/GradedArraysTensorAlgebraExt.jl#L72-L74

Added lines #L72 - L74 were not covered by tests
end
end
8 changes: 8 additions & 0 deletions src/gradedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@ using BlockSparseArrays:
AbstractBlockSparseMatrix,
AnyAbstractBlockSparseArray,
BlockSparseArray,
BlockSparseMatrix,
BlockSparseVector,
blocktype
using ..GradedUnitRanges: AbstractGradedUnitRange, dual
using LinearAlgebra: Adjoint
using TypeParameterAccessors: similartype, unwrap_array_type

const GradedArray{T,M,A,Blocks,Axes} = BlockSparseArray{
T,M,A,Blocks,Axes
} where {Axes<:Tuple{AbstractGradedUnitRange,Vararg{AbstractGradedUnitRange}}}
const GradedMatrix{T,A,Blocks,Axes} = GradedArray{T,2,A,Blocks,Axes}
const GradedVector{T,A,Blocks,Axes} = GradedArray{T,1,A,Blocks,Axes}

# TODO: Handle this through some kind of trait dispatch, maybe
# a `SymmetryStyle`-like trait to check if the block sparse
# matrix has graded axes.
Expand Down
6 changes: 3 additions & 3 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[compat]
Aqua = "0.8.11"
BlockArrays = "1.5.0"
BlockSparseArrays = "0.4.0"
BlockArrays = "1.6.0"
BlockSparseArrays = "0.4.2"
GradedArrays = "0.2.0"
LinearAlgebra = "1.10.0"
Random = "1.10.0"
SafeTestsets = "0.1.0"
SparseArraysBase = "0.5.4"
Suppressor = "0.2.8"
TensorAlgebra = "0.2.7"
TensorAlgebra = "0.3.2"
TensorProducts = "0.1.3"
Test = "1.10.0"
TestExtras = "0.3.1"
18 changes: 16 additions & 2 deletions test/test_gradedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ using BlockArrays:
AbstractBlockArray, Block, BlockedOneTo, blockedrange, blocklengths, blocksize
using BlockSparseArrays:
BlockSparseArray, BlockSparseMatrix, BlockSparseVector, blockstoredlength
using GradedArrays:
using GradedArrays: GradedArray, GradedMatrix, GradedVector
using GradedArrays.GradedUnitRanges:
GradedUnitRanges,
GradedOneTo,
GradedUnitRange,
GradedUnitRangeDual,
blocklabels,
dag,
dual,
gradedrange,
Expand All @@ -31,6 +31,19 @@ end

const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "GradedArray (eltype=$elt)" for elt in elts
@testset "definitions" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(undef)
@test !(a isa GradedArray) # no type piracy
v = BlockSparseArray{elt}(undef, r)
@test v isa GradedArray
@test v isa GradedVector
m = BlockSparseArray{elt}(undef, r, r)
@test m isa GradedArray
@test m isa GradedMatrix
a = BlockSparseArray{elt}(undef, r, r, r)
@test a isa GradedArray
end
@testset "map" begin
d1 = gradedrange([U1(0) => 2, U1(1) => 2])
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
Expand Down Expand Up @@ -60,6 +73,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
r = gradedrange([U1(0) => 2, U1(1) => 2])
a = zeros(r, r, r, r)
@test a isa BlockSparseArray{Float64}
@test a isa GradedArray
@test eltype(a) === Float64
@test size(a) == (4, 4, 4, 4)
@test iszero(a)
Expand Down
82 changes: 49 additions & 33 deletions test/test_tensoralgebraext.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
using BlockArrays: Block, blocksize
using BlockSparseArrays: BlockSparseArray
using GradedArrays: GradedOneTo, blocklabels, dual, gradedrange
using BlockSparseArrays: BlockSparseArray, BlockSparseMatrix
using GradedArrays: GradedOneTo, blocklabels, dual, flip, gradedrange, space_isequal
using GradedArrays.SymmetrySectors: U1
using Random: randn!
using TensorAlgebra: contract, fusedims, splitdims
using TensorAlgebra: contract, matricize, unmatricize
using Test: @test, @test_broken, @testset

function randn_blockdiagonal(elt::Type, axes::Tuple)
Expand All @@ -18,7 +18,51 @@ end

const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "`contract` `GradedArray` (eltype=$elt)" for elt in elts
@testset "GradedOneTo with U(1)" begin
@testset "matricize" begin
d1 = gradedrange([U1(0) => 1, U1(1) => 1])
d2 = gradedrange([U1(0) => 1, U1(1) => 1])
a = randn_blockdiagonal(elt, (d1, d2, dual(d1), dual(d2)))
m = matricize(a, (1, 2), (3, 4))
@test m isa BlockSparseMatrix
@test space_isequal(axes(m, 1), gradedrange([U1(0) => 1, U1(1) => 2, U1(2) => 1]))
@test space_isequal(
axes(m, 2), flip(gradedrange([U1(0) => 1, U1(-1) => 2, U1(-2) => 1]))
)

for I in CartesianIndices(m)
if I ∈ CartesianIndex.([(1, 1), (4, 4)])
@test !iszero(m[I])
else
@test iszero(m[I])
end
end
@test a[1, 1, 1, 1] == m[1, 1]
@test a[2, 2, 2, 2] == m[4, 4]
@test blocksize(m) == (3, 3)
@test a == unmatricize(m, (d1, d2), (dual(d1), dual(d2)))

# check block fusing and splitting
d = gradedrange([U1(0) => 2, U1(1) => 1])
b = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
@test unmatricize(
matricize(b, (1, 2), (3, 4)), (axes(b, 1), axes(b, 2)), (axes(b, 3), axes(b, 4))
) == b

d1234 = gradedrange([U1(-2) => 1, U1(-1) => 4, U1(0) => 6, U1(1) => 4, U1(2) => 1])
m = matricize(a, (1, 2, 3, 4), ())
@test m isa BlockSparseMatrix
@test space_isequal(axes(m, 1), d1234)
@test space_isequal(axes(m, 2), flip(gradedrange([U1(0) => 1])))
@test a == unmatricize(m, (d1, d2, dual(d1), dual(d2)), ())

m = matricize(a, (), (1, 2, 3, 4))
@test m isa BlockSparseMatrix
@test space_isequal(axes(m, 1), gradedrange([U1(0) => 1]))
@test space_isequal(axes(m, 2), dual(d1234))
@test a == unmatricize(m, (), (d1, d2, dual(d1), dual(d2)))
end

@testset "contract with U(1)" begin
d = gradedrange([U1(0) => 2, U1(1) => 3])
a1 = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
a2 = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
Expand All @@ -38,14 +82,12 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test a_dest ≈ a_dest_dense

# matrix vector
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
#=
a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockSparseArray
@test a_dest ≈ a_dest_dense
=#

# vector matrix
a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
Expand All @@ -71,30 +113,4 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test a_dest isa BlockSparseArray
@test a_dest ≈ a_dest_dense
end
@testset "fusedims" begin
d1 = gradedrange([U1(0) => 1, U1(1) => 1])
d2 = gradedrange([U1(0) => 1, U1(1) => 1])
a = randn_blockdiagonal(elt, (d1, d2, d1, d2))
m = fusedims(a, (1, 2), (3, 4))
for ax in axes(m)
@test ax isa GradedOneTo
@test blocklabels(ax) == [U1(0), U1(1), U1(2)]
end
for I in CartesianIndices(m)
if I ∈ CartesianIndex.([(1, 1), (4, 4)])
@test !iszero(m[I])
else
@test iszero(m[I])
end
end
@test a[1, 1, 1, 1] == m[1, 1]
@test a[2, 2, 2, 2] == m[4, 4]
@test blocksize(m) == (3, 3)
@test a == splitdims(m, (d1, d2), (d1, d2))

# check block fusing and splitting
d = gradedrange([U1(0) => 2, U1(1) => 1])
a = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
@test splitdims(fusedims(a, (1, 2), (3, 4)), axes(a)...) == a
end
end
Loading