From f985875ec94e721a354554750b4abe36d6437d16 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 18 May 2025 19:00:04 -0400 Subject: [PATCH 1/7] Define simpler codepath for blockwise map --- Project.toml | 2 +- src/blocksparsearrayinterface/map.jl | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b583b073..d5200c3b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.5.2" +version = "0.5.3" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/blocksparsearrayinterface/map.jl b/src/blocksparsearrayinterface/map.jl index 2df71338..4e897deb 100644 --- a/src/blocksparsearrayinterface/map.jl +++ b/src/blocksparsearrayinterface/map.jl @@ -1,6 +1,21 @@ +using BlockArrays: BlockRange, blockisequal using DerivableInterfaces: @interface, AbstractArrayInterface, interface using GPUArraysCore: @allowscalar +function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...) + # TODO: This assumes element types are numbers, generalize this logic. + f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest)) + Is = if f_preserves_zeros + ∪(map(eachblockstoredindex, (a_dest, a_srcs...))...) + else + BlockRange(a_dest) + end + for I in Is + map!(f, view(a_dest, I), map(Base.Fix2(view, I), a_srcs)...) + end + return a_dest +end + # TODO: Rewrite this so that it takes the blocking structure # made by combining the blocking of the axes (i.e. the blocking that # is used to determine `union_stored_blocked_cartesianindices(...)`). @@ -16,6 +31,16 @@ using GPUArraysCore: @allowscalar @interface interface map_zero_dim!(f, a_dest, a_srcs...) return a_dest end + blockwise = all( + ntuple(ndims(a_dest)) do dim + ax = map(Base.Fix2(axes, dim), (a_dest, a_srcs...)) + return blockisequal(ax...) + end, + ) + if blockwise + map_blockwise!(f, a_dest, a_srcs...) + return a_dest + end # TODO: This assumes element types are numbers, generalize this logic. f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest)) a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs) From b26e59a2a328add0fae215b961352ac0a4687dab Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 19 May 2025 08:52:53 -0400 Subject: [PATCH 2/7] Fix for GPUs --- src/blocksparsearrayinterface/map.jl | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/blocksparsearrayinterface/map.jl b/src/blocksparsearrayinterface/map.jl index 4e897deb..326b47a0 100644 --- a/src/blocksparsearrayinterface/map.jl +++ b/src/blocksparsearrayinterface/map.jl @@ -11,7 +11,22 @@ function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...) BlockRange(a_dest) end for I in Is - map!(f, view(a_dest, I), map(Base.Fix2(view, I), a_srcs)...) + # TODO: Use: + # block_dest = @view a_dest[I] + # or: + # block_dest = @view! a_dest[I] + block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(I))...] + # TODO: Use: + # block_srcs = map(a_src -> @view(a_src[I]), a_srcs) + block_srcs = map(a_srcs) do a_src + return blocks_maybe_single(a_src)[Int.(Tuple(I))...] + end + # TODO: Use `map!!` to handle immutable blocks. + map!(f, block_dest, block_srcs...) + # Replace the entire block, handles initializing new blocks + # or if blocks are immutable. + # TODO: Use `a_dest[I] = block_dest`. + blocks(a_dest)[Int.(Tuple(I))...] = block_dest end return a_dest end From 1ea6c94fc04d6e631abc2a3bc0fce63d7f6f00ca Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 19 May 2025 09:12:56 -0400 Subject: [PATCH 3/7] Fix for GradedArrays --- src/blocksparsearrayinterface/blocksparsearrayinterface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index a76edbfb..06eaa6f6 100644 --- a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -445,7 +445,7 @@ end to_blocks_indices(I::BlockSlice{<:BlockRange{1}}) = Int.(I.block) to_blocks_indices(I::BlockIndices{<:Vector{<:Block{1}}}) = Int.(I.blocks) -to_blocks_indices(I::Base.Slice{<:BlockedOneTo}) = Base.OneTo(blocklength(I.indices)) +to_blocks_indices(I::Base.Slice) = Base.OneTo(blocklength(I.indices)) @interface ::AbstractBlockSparseArrayInterface function BlockArrays.blocks( a::SubArray{<:Any,<:Any,<:Any,<:Tuple{Vararg{BlockSliceCollection}}} From 1e47d75c753a7e4f8641c4c28ec7e56c3de36db8 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 19 May 2025 10:54:15 -0400 Subject: [PATCH 4/7] Try fixing tests --- .../BlockArraysExtensions.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/BlockArraysExtensions/BlockArraysExtensions.jl b/src/BlockArraysExtensions/BlockArraysExtensions.jl index 3853304f..b4c2d926 100644 --- a/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -290,13 +290,6 @@ function blockrange(axis::AbstractUnitRange, r::Int) return error("Slicing with integer values isn't supported.") end -function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}}) - for b in r - @assert b ∈ blockaxes(axis, 1) - end - return r -end - # This handles changing the blocking, for example: # a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2]) # I = blockedrange([4, 4]) @@ -315,13 +308,20 @@ end # I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]) # I = BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]) # a[I, I] -function blockrange(axis::BlockedOneTo{<:Integer}, r::AbstractBlockVector{<:Block{1}}) +function blockrange(axis::AbstractUnitRange, r::AbstractBlockVector{<:Block{1}}) for b in r @assert b ∈ blockaxes(axis, 1) end return only(blockaxes(r)) end +function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}}) + for b in r + @assert b ∈ blockaxes(axis, 1) + end + return r +end + using BlockArrays: BlockSlice function blockrange(axis::AbstractUnitRange, r::BlockSlice) return blockrange(axis, r.block) From 3e0c3cebd8a1c73584cb2b9ababc583843d07d81 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 19 May 2025 11:17:04 -0400 Subject: [PATCH 5/7] Better code organization --- src/blocksparsearrayinterface/map.jl | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/blocksparsearrayinterface/map.jl b/src/blocksparsearrayinterface/map.jl index 326b47a0..74b8b203 100644 --- a/src/blocksparsearrayinterface/map.jl +++ b/src/blocksparsearrayinterface/map.jl @@ -2,11 +2,27 @@ using BlockArrays: BlockRange, blockisequal using DerivableInterfaces: @interface, AbstractArrayInterface, interface using GPUArraysCore: @allowscalar +# Check if the block structures are the same. +function same_block_structure(as::AbstractArray...) + isempty(as) && return true + return all( + ntuple(ndims(first(as))) do dim + ax = map(Base.Fix2(axes, dim), as) + return blockisequal(ax...) + end, + ) +end + +# Find the common stored blocks, assuming the block structures are the same. +function union_eachblockstoredindex(as::AbstractArray...) + return ∪(map(eachblockstoredindex, (a_dest, a_srcs...))...) +end + function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...) # TODO: This assumes element types are numbers, generalize this logic. f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest)) Is = if f_preserves_zeros - ∪(map(eachblockstoredindex, (a_dest, a_srcs...))...) + union_eachblockstoredindex(a_dest, a_srcs...) else BlockRange(a_dest) end @@ -46,13 +62,7 @@ end @interface interface map_zero_dim!(f, a_dest, a_srcs...) return a_dest end - blockwise = all( - ntuple(ndims(a_dest)) do dim - ax = map(Base.Fix2(axes, dim), (a_dest, a_srcs...)) - return blockisequal(ax...) - end, - ) - if blockwise + if same_block_structure(a_dest, a_srcs...) map_blockwise!(f, a_dest, a_srcs...) return a_dest end From c3a0d5d31590aab7bfd5046d462d3b78c77cddda Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 19 May 2025 11:28:56 -0400 Subject: [PATCH 6/7] Fix typo --- src/blocksparsearrayinterface/map.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blocksparsearrayinterface/map.jl b/src/blocksparsearrayinterface/map.jl index 74b8b203..b046eee2 100644 --- a/src/blocksparsearrayinterface/map.jl +++ b/src/blocksparsearrayinterface/map.jl @@ -15,7 +15,7 @@ end # Find the common stored blocks, assuming the block structures are the same. function union_eachblockstoredindex(as::AbstractArray...) - return ∪(map(eachblockstoredindex, (a_dest, a_srcs...))...) + return ∪(map(eachblockstoredindex, as)) end function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...) From 0f7f8849d7167cddb1cc4b0415fb0588f1fee2c3 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 19 May 2025 11:44:53 -0400 Subject: [PATCH 7/7] Fix another typo --- src/blocksparsearrayinterface/map.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blocksparsearrayinterface/map.jl b/src/blocksparsearrayinterface/map.jl index b046eee2..60a6bc8c 100644 --- a/src/blocksparsearrayinterface/map.jl +++ b/src/blocksparsearrayinterface/map.jl @@ -15,7 +15,7 @@ end # Find the common stored blocks, assuming the block structures are the same. function union_eachblockstoredindex(as::AbstractArray...) - return ∪(map(eachblockstoredindex, as)) + return ∪(map(eachblockstoredindex, as)...) end function map_blockwise!(f, a_dest::AbstractArray, a_srcs::AbstractArray...)