From 33c6fc41f49938c9cf464a5272e4df469dd8a7f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 14 Feb 2025 18:03:19 -0500 Subject: [PATCH 1/4] quickfix --- .../BlockSparseArraysTensorAlgebraExt.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl index fc5237c7..4a1766cd 100644 --- a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl +++ b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl @@ -1,6 +1,6 @@ module BlockSparseArraysTensorAlgebraExt using BlockArrays: AbstractBlockedUnitRange -using GradedUnitRanges: tensor_product +using GradedUnitRanges: tensor_product, gradedrange using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion function TensorAlgebra.:⊗(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange) @@ -94,13 +94,17 @@ function TensorAlgebra.splitdims( groupreducewhile(tensor_product, split_axes, ndims(a); init=OneToOne()) do i, axis return length(axis) ≤ length(axes(a, i)) end - blockperms = invblockperm.(blocksortperm.(axes_prod)) + blockperms = blocksortperm.(axes_prod) + sorted_axes = ntuple( + i -> gradedrange(map(b -> length(axes_prod[i][b]), blockperms[i])), ndims(a) + ) + # TODO: This is doing extra copies of the blocks, # use `@view a[axes_prod...]` instead. # That will require implementing some reindexing logic # for this combination of slicing. - a_unblocked = a[axes_prod...] - a_blockpermed = a_unblocked[blockperms...] + a_unblocked = a[sorted_axes...] + a_blockpermed = a_unblocked[invblockperm.(blockperms)...] return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...) end From e4dfa352b974840c526ed4ea055ab69b47110f10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 18 Feb 2025 12:42:29 -0500 Subject: [PATCH 2/4] explicit test --- test/test_gradedunitrangesext.jl | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/test/test_gradedunitrangesext.jl b/test/test_gradedunitrangesext.jl index f5632580..1f76c9d9 100644 --- a/test/test_gradedunitrangesext.jl +++ b/test/test_gradedunitrangesext.jl @@ -19,10 +19,12 @@ using SymmetrySectors: U1 using TensorAlgebra: fusedims, splitdims using LinearAlgebra: adjoint using Random: randn! -function blockdiagonal!(f, a::AbstractArray) - for i in 1:minimum(blocksize(a)) +function randn_blockdiagonal(elt::Type, axes::Tuple) + a = BlockSparseArray{elt}(axes) + blockdiaglength = minimum(blocksize(a)) + for i in 1:blockdiaglength b = Block(ntuple(Returns(i), ndims(a))) - a[b] = f(a[b]) + a[b] = randn!(a[b]) end return a end @@ -32,8 +34,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "map" begin d1 = gradedrange([U1(0) => 2, U1(1) => 2]) d2 = gradedrange([U1(0) => 2, U1(1) => 2]) - a = BlockSparseArray{elt}(d1, d2, d1, d2) - blockdiagonal!(randn!, a) + a = randn_blockdiagonal(elt, (d1, d2, d1, d2)) @test axes(a, 1) isa GradedOneTo @test axes(view(a, 1:4, 1:4, 1:4, 1:4), 1) isa GradedOneTo @@ -89,8 +90,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "fusedims" begin d1 = gradedrange([U1(0) => 1, U1(1) => 1]) d2 = gradedrange([U1(0) => 1, U1(1) => 1]) - a = BlockSparseArray{elt}(d1, d2, d1, d2) - blockdiagonal!(randn!, a) + a = randn_blockdiagonal(elt, (d1, d2, d1, d2)) m = fusedims(a, (1, 2), (3, 4)) for ax in axes(m) @test ax isa GradedOneTo @@ -107,6 +107,11 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a[2, 2, 2, 2] == m[4, 4] @test blocksize(m) == (3, 3) @test a == splitdims(m, (d1, d2), (d1, d2)) + + # check block fusing and splitting + d = gradedrange([U1(0) => 2, U1(1) => 1]) + a = randn_blockdiagonal(elt, (d, d, dual(d), dual(d))) + @test splitdims(fusedims(a, (1, 2), (3, 4)), axes(a)...) == a end @testset "dual axes" begin From ade5505371f276c3d52d6aba6232b27ec5ce4fa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 18 Feb 2025 12:43:57 -0500 Subject: [PATCH 3/4] cleaner code --- .../BlockSparseArraysTensorAlgebraExt.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl index 4a1766cd..37f03d09 100644 --- a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl +++ b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl @@ -1,6 +1,6 @@ module BlockSparseArraysTensorAlgebraExt using BlockArrays: AbstractBlockedUnitRange -using GradedUnitRanges: tensor_product, gradedrange +using GradedUnitRanges: tensor_product using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion function TensorAlgebra.:⊗(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange) @@ -95,9 +95,7 @@ function TensorAlgebra.splitdims( return length(axis) ≤ length(axes(a, i)) end blockperms = blocksortperm.(axes_prod) - sorted_axes = ntuple( - i -> gradedrange(map(b -> length(axes_prod[i][b]), blockperms[i])), ndims(a) - ) + sorted_axes = map((r, I) -> only(axes(r[I])), axes_prod, blockperms) # TODO: This is doing extra copies of the blocks, # use `@view a[axes_prod...]` instead. From 1ddbfb4d012df830924555e535e4e4ea2d2890e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 18 Feb 2025 14:53:10 -0500 Subject: [PATCH 4/4] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1944cd5c..96944c68 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.2.18" +version = "0.2.19" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"