diff --git a/Project.toml b/Project.toml index 1944cd5c..96944c68 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.2.18" +version = "0.2.19" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl index fc5237c7..37f03d09 100644 --- a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl +++ b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl @@ -94,13 +94,15 @@ function TensorAlgebra.splitdims( groupreducewhile(tensor_product, split_axes, ndims(a); init=OneToOne()) do i, axis return length(axis) ≤ length(axes(a, i)) end - blockperms = invblockperm.(blocksortperm.(axes_prod)) + blockperms = blocksortperm.(axes_prod) + sorted_axes = map((r, I) -> only(axes(r[I])), axes_prod, blockperms) + # TODO: This is doing extra copies of the blocks, # use `@view a[axes_prod...]` instead. # That will require implementing some reindexing logic # for this combination of slicing. - a_unblocked = a[axes_prod...] - a_blockpermed = a_unblocked[blockperms...] + a_unblocked = a[sorted_axes...] + a_blockpermed = a_unblocked[invblockperm.(blockperms)...] return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...) end diff --git a/test/test_gradedunitrangesext.jl b/test/test_gradedunitrangesext.jl index f5632580..1f76c9d9 100644 --- a/test/test_gradedunitrangesext.jl +++ b/test/test_gradedunitrangesext.jl @@ -19,10 +19,12 @@ using SymmetrySectors: U1 using TensorAlgebra: fusedims, splitdims using LinearAlgebra: adjoint using Random: randn! -function blockdiagonal!(f, a::AbstractArray) - for i in 1:minimum(blocksize(a)) +function randn_blockdiagonal(elt::Type, axes::Tuple) + a = BlockSparseArray{elt}(axes) + blockdiaglength = minimum(blocksize(a)) + for i in 1:blockdiaglength b = Block(ntuple(Returns(i), ndims(a))) - a[b] = f(a[b]) + a[b] = randn!(a[b]) end return a end @@ -32,8 +34,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "map" begin d1 = gradedrange([U1(0) => 2, U1(1) => 2]) d2 = gradedrange([U1(0) => 2, U1(1) => 2]) - a = BlockSparseArray{elt}(d1, d2, d1, d2) - blockdiagonal!(randn!, a) + a = randn_blockdiagonal(elt, (d1, d2, d1, d2)) @test axes(a, 1) isa GradedOneTo @test axes(view(a, 1:4, 1:4, 1:4, 1:4), 1) isa GradedOneTo @@ -89,8 +90,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "fusedims" begin d1 = gradedrange([U1(0) => 1, U1(1) => 1]) d2 = gradedrange([U1(0) => 1, U1(1) => 1]) - a = BlockSparseArray{elt}(d1, d2, d1, d2) - blockdiagonal!(randn!, a) + a = randn_blockdiagonal(elt, (d1, d2, d1, d2)) m = fusedims(a, (1, 2), (3, 4)) for ax in axes(m) @test ax isa GradedOneTo @@ -107,6 +107,11 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a[2, 2, 2, 2] == m[4, 4] @test blocksize(m) == (3, 3) @test a == splitdims(m, (d1, d2), (d1, d2)) + + # check block fusing and splitting + d = gradedrange([U1(0) => 2, U1(1) => 1]) + a = randn_blockdiagonal(elt, (d, d, dual(d), dual(d))) + @test splitdims(fusedims(a, (1, 2), (3, 4)), axes(a)...) == a end @testset "dual axes" begin