From c9702e7311e06723040732c4dcc91ad4153cfa1e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 16 Dec 2024 00:39:37 -0500 Subject: [PATCH] Set up package extensions, improve permutedims --- Project.toml | 16 +++-- .../{src => }/BlockSparseArraysAdaptExt.jl | 2 +- .../test/Project.toml | 3 - ...> BlockSparseArraysGradedUnitRangesExt.jl} | 7 +- .../BlockSparseArraysTensorAlgebraExt.jl} | 30 +++++++- .../reducewhile.jl | 0 .../test/Project.toml | 3 - src/BlockSparseArrays.jl | 5 -- src/abstractblocksparsearray/map.jl | 16 +++++ .../blocksparsearrayinterface.jl | 71 ++++++++++++++++--- test/basics/test_basics.jl | 2 +- test/basics/test_extensions.jl | 2 - .../basics/test_gradedunitrangesext.jl | 0 .../basics/test_tensoralgebraext.jl | 0 14 files changed, 119 insertions(+), 38 deletions(-) rename ext/BlockSparseArraysAdaptExt/{src => }/BlockSparseArraysAdaptExt.jl (68%) delete mode 100644 ext/BlockSparseArraysGradedUnitRangesExt/test/Project.toml rename ext/BlockSparseArraysTensorAlgebraExt/{src/BlockSparseArraysTensorAlgebraExt.jl => BlockSparseArraysGradedUnitRangesExt.jl} (69%) rename ext/{BlockSparseArraysGradedUnitRangesExt/src/BlockSparseArraysGradedUnitRangesExt.jl => BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl} (87%) rename ext/{BlockSparseArraysGradedUnitRangesExt/src => BlockSparseArraysTensorAlgebraExt}/reducewhile.jl (100%) delete mode 100644 ext/BlockSparseArraysTensorAlgebraExt/test/Project.toml delete mode 100644 test/basics/test_extensions.jl rename ext/BlockSparseArraysGradedUnitRangesExt/test/runtests.jl => test/basics/test_gradedunitrangesext.jl (100%) rename ext/BlockSparseArraysTensorAlgebraExt/test/runtests.jl => test/basics/test_tensoralgebraext.jl (100%) diff --git a/Project.toml b/Project.toml index cfa01700..ad3b4115 100644 --- a/Project.toml +++ b/Project.toml @@ -10,17 +10,21 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2" Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" -GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5" -LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -NestedPermutedDimsArrays = "2c2a8ec4-3cfc-4276-aa3e-1307b4294e58" SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66" -TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" +[weakdeps] +LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993" +TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" + +[extensions] +BlockSparseArraysAdaptExt = "Adapt" +BlockSparseArraysTensorAlgebraExt = ["LabelledNumbers", "TensorAlgebra"] + [compat] Adapt = "4.1.1" Aqua = "0.8.9" @@ -29,16 +33,20 @@ BlockArrays = "1.2.0" Derive = "0.3.1" Dictionaries = "0.4.3" GPUArraysCore = "0.1.0" +GradedUnitRanges = "0.1.0" +LabelledNumbers = "0.1.0" LinearAlgebra = "1.10" MacroTools = "0.5.13" SparseArraysBase = "0.2" SplitApplyCombine = "1.2.3" TensorAlgebra = "0.1.0" +TypeParameterAccessors = "0.1.0" Test = "1.10" julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] diff --git a/ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl b/ext/BlockSparseArraysAdaptExt/BlockSparseArraysAdaptExt.jl similarity index 68% rename from ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl rename to ext/BlockSparseArraysAdaptExt/BlockSparseArraysAdaptExt.jl index 68cbf05e..b6a3d49d 100644 --- a/ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl +++ b/ext/BlockSparseArraysAdaptExt/BlockSparseArraysAdaptExt.jl @@ -1,5 +1,5 @@ module BlockSparseArraysAdaptExt using Adapt: Adapt, adapt -using ..BlockSparseArrays: AbstractBlockSparseArray, map_stored_blocks +using BlockSparseArrays: AbstractBlockSparseArray, map_stored_blocks Adapt.adapt_structure(to, x::AbstractBlockSparseArray) = map_stored_blocks(adapt(to), x) end diff --git a/ext/BlockSparseArraysGradedUnitRangesExt/test/Project.toml b/ext/BlockSparseArraysGradedUnitRangesExt/test/Project.toml deleted file mode 100644 index 7934869d..00000000 --- a/ext/BlockSparseArraysGradedUnitRangesExt/test/Project.toml +++ /dev/null @@ -1,3 +0,0 @@ -[deps] -BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysGradedUnitRangesExt.jl similarity index 69% rename from ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl rename to ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysGradedUnitRangesExt.jl index a734ded9..19a19b4f 100644 --- a/ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl +++ b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysGradedUnitRangesExt.jl @@ -1,13 +1,8 @@ module BlockSparseArraysTensorAlgebraExt using BlockArrays: AbstractBlockedUnitRange -using ..BlockSparseArrays: AbstractBlockSparseArray, blockreshape -using GradedUnitRanges: tensor_product +using BlockSparseArrays: AbstractBlockSparseArray, blockreshape using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion -function TensorAlgebra.:⊗(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange) - return tensor_product(a1, a2) -end - TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion() function TensorAlgebra.fusedims( diff --git a/ext/BlockSparseArraysGradedUnitRangesExt/src/BlockSparseArraysGradedUnitRangesExt.jl b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl similarity index 87% rename from ext/BlockSparseArraysGradedUnitRangesExt/src/BlockSparseArraysGradedUnitRangesExt.jl rename to ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl index 5846ce5a..b398c209 100644 --- a/ext/BlockSparseArraysGradedUnitRangesExt/src/BlockSparseArraysGradedUnitRangesExt.jl +++ b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl @@ -1,4 +1,30 @@ -module BlockSparseArraysGradedUnitRangesExt +module BlockSparseArraysTensorAlgebraExt +using BlockArrays: AbstractBlockedUnitRange +using GradedUnitRanges: tensor_product +using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion + +function TensorAlgebra.:⊗(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange) + return tensor_product(a1, a2) +end + +using BlockArrays: AbstractBlockedUnitRange +using BlockSparseArrays: AbstractBlockSparseArray, blockreshape +using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion + +TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion() + +function TensorAlgebra.fusedims( + ::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange... +) + return blockreshape(a, axes) +end + +function TensorAlgebra.splitdims( + ::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange... +) + return blockreshape(a, axes) +end + using BlockArrays: AbstractBlockVector, AbstractBlockedUnitRange, @@ -6,7 +32,7 @@ using BlockArrays: BlockIndexRange, blockedrange, blocks -using ..BlockSparseArrays: +using BlockSparseArrays: BlockSparseArrays, AbstractBlockSparseArray, AbstractBlockSparseArrayInterface, diff --git a/ext/BlockSparseArraysGradedUnitRangesExt/src/reducewhile.jl b/ext/BlockSparseArraysTensorAlgebraExt/reducewhile.jl similarity index 100% rename from ext/BlockSparseArraysGradedUnitRangesExt/src/reducewhile.jl rename to ext/BlockSparseArraysTensorAlgebraExt/reducewhile.jl diff --git a/ext/BlockSparseArraysTensorAlgebraExt/test/Project.toml b/ext/BlockSparseArraysTensorAlgebraExt/test/Project.toml deleted file mode 100644 index 7934869d..00000000 --- a/ext/BlockSparseArraysTensorAlgebraExt/test/Project.toml +++ /dev/null @@ -1,3 +0,0 @@ -[deps] -BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/BlockSparseArrays.jl b/src/BlockSparseArrays.jl index 41a1b930..3f54093e 100644 --- a/src/BlockSparseArrays.jl +++ b/src/BlockSparseArrays.jl @@ -22,9 +22,4 @@ include("abstractblocksparsearray/cat.jl") include("blocksparsearray/defaults.jl") include("blocksparsearray/blocksparsearray.jl") include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl") -include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl") -include( - "../ext/BlockSparseArraysGradedUnitRangesExt/src/BlockSparseArraysGradedUnitRangesExt.jl" -) -include("../ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl") end diff --git a/src/abstractblocksparsearray/map.jl b/src/abstractblocksparsearray/map.jl index 93ba526d..ef5d81e3 100644 --- a/src/abstractblocksparsearray/map.jl +++ b/src/abstractblocksparsearray/map.jl @@ -137,6 +137,22 @@ function Base.copyto!( return @interface interface(a_src) copyto!(a_dest, a_src) end +# This avoids going through the generic version that calls `Base.permutedims!`, +# which eventually calls block sparse `map!`, which involves slicing operations +# that are not friendly to GPU (since they involve `SubArray` wrapping +# `PermutedDimsArray`). +# TODO: Handle slicing better in `map!` so that this can be removed. +function Base.permutedims(a::AnyAbstractBlockSparseArray, perm) + @interface interface(a) permutedims(a, perm) +end + +# The `::AbstractBlockSparseArrayInterface` version +# has a special case for when `a_dest` and `PermutedDimsArray(a_src, perm)` +# have the same blocking, and therefore can just use: +# ```julia +# permutedims!(blocks(a_dest), blocks(a_src), perm) +# ``` +# TODO: Handle slicing better in `map!` so that this can be removed. function Base.permutedims!(a_dest, a_src::AnyAbstractBlockSparseArray, perm) return @interface interface(a_src) permutedims!(a_dest, a_src, perm) end diff --git a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index a3e391e0..b4583c6f 100644 --- a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -14,10 +14,17 @@ using BlockArrays: blocklengths, blocks, findblockindex -using Derive: Derive, @interface +using Derive: Derive, @interface, DefaultArrayInterface using LinearAlgebra: Adjoint, Transpose using SparseArraysBase: - AbstractSparseArrayInterface, eachstoredindex, perm, iperm, storedlength, storedvalues + AbstractSparseArrayInterface, + getstoredindex, + getunstoredindex, + eachstoredindex, + perm, + iperm, + storedlength, + storedvalues # Like `SparseArraysBase.eachstoredindex` but # at the block level, i.e. iterates over the @@ -154,6 +161,39 @@ end return a end +# Version of `permutedims!` that assumes the destination and source +# have the same blocking. +# TODO: Delete this and handle this logic in block sparse `map!`. +function blocksparse_permutedims!(a_dest::AbstractArray, a_src::AbstractArray, perm) + blocks(a_dest) .= blocks(PermutedDimsArray(a_src, perm)) + return a_dest +end + +# We overload `permutedims` here so that we can assume the destination and source +# have the same blocking and avoid non-GPU friendly slicing operations in block sparse `map!`. +# TODO: Delete this and handle this logic in block sparse `map!`. +@interface ::AbstractBlockSparseArrayInterface function Base.permutedims( + a::AbstractArray, perm +) + a_dest = similar(PermutedDimsArray(a, perm)) + blocksparse_permutedims!(a_dest, a, perm) + return a_dest +end + +# We overload `permutedims!` here so that we can special case when the destination and source +# have the same blocking and avoid non-GPU friendly slicing operations in block sparse `map!`. +# TODO: Delete this and handle this logic in block sparse `map!`. +@interface ::AbstractBlockSparseArrayInterface function Base.permutedims!( + a_dest::AbstractArray, a_src::AbstractArray, perm +) + if all(blockisequal.(axes(a_dest), axes(PermutedDimsArray(a_src, perm)))) + blocksparse_permutedims!(a_dest, a_src, perm) + return a_dest + end + @interface DefaultArrayInterface() permutedims!(a_dest, a_src, perm) + return a_dest +end + @interface ::AbstractBlockSparseArrayInterface function Base.fill!(a::AbstractArray, value) # TODO: Only do this check if `value isa Number`? if iszero(value) @@ -190,6 +230,7 @@ _getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), i # Represents the array of arrays of a `PermutedDimsArray` # wrapping a block spare array, i.e. `blocks(array)` where `a` is a `PermutedDimsArray`. +# TODO: Delete this in favor of `NestedPermutedDimsArrays.NestedPermutedDimsArray`. struct SparsePermutedDimsArrayBlocks{ T,N,BlockType<:AbstractArray{T,N},Array<:PermutedDimsArray{T,N} } <: AbstractSparseArray{BlockType,N} @@ -203,23 +244,31 @@ end function Base.size(a::SparsePermutedDimsArrayBlocks) return _getindices(size(blocks(parent(a.array))), _perm(a.array)) end -function Base.getindex( +function SparseArraysBase.isstored( + a::SparsePermutedDimsArrayBlocks{<:Any,N}, index::Vararg{Int,N} +) where {N} + return isstored(blocks(parent(a.array)), _getindices(index, _invperm(a.array))...) +end +function SparseArraysBase.getstoredindex( a::SparsePermutedDimsArrayBlocks{<:Any,N}, index::Vararg{Int,N} ) where {N} return PermutedDimsArray( - blocks(parent(a.array))[_getindices(index, _invperm(a.array))...], _perm(a.array) + getstoredindex(blocks(parent(a.array)), _getindices(index, _invperm(a.array))...), + _perm(a.array), + ) +end +function SparseArraysBase.getunstoredindex( + a::SparsePermutedDimsArrayBlocks{<:Any,N}, index::Vararg{Int,N} +) where {N} + return PermutedDimsArray( + getunstoredindex(blocks(parent(a.array)), _getindices(index, _invperm(a.array))...), + _perm(a.array), ) end function SparseArraysBase.eachstoredindex(a::SparsePermutedDimsArrayBlocks) return map(I -> _getindices(I, _perm(a.array)), eachstoredindex(blocks(parent(a.array)))) end -# TODO: Either make this the generic interface or define -# `SparseArraysBase.sparse_storage`, which is used -# to defined this. -function SparseArraysBase.storedlength(a::SparsePermutedDimsArrayBlocks) - return length(eachstoredindex(a)) -end -## TODO: Delete. +## TODO: Define `storedvalues` instead. ## function SparseArraysBase.sparse_storage(a::SparsePermutedDimsArrayBlocks) ## return error("Not implemented") ## end diff --git a/test/basics/test_basics.jl b/test/basics/test_basics.jl index def25c8f..6bb31674 100644 --- a/test/basics/test_basics.jl +++ b/test/basics/test_basics.jl @@ -415,7 +415,7 @@ arrayts = (Array, JLArray) a[Block(3, 2, 2, 3)] = dev(randn(elt, 1, 2, 2, 1)) perm = (2, 3, 4, 1) for b in (PermutedDimsArray(a, perm), permutedims(a, perm)) - @test Array(b) == permutedims(Array(a), perm) + @test @allowscalar(Array(b)) == permutedims(Array(a), perm) @test issetequal(eachblockstoredindex(b), [Block(2, 2, 3, 3)]) @test @allowscalar b[Block(2, 2, 3, 3)] == permutedims(a[Block(3, 2, 2, 3)], perm) end diff --git a/test/basics/test_extensions.jl b/test/basics/test_extensions.jl deleted file mode 100644 index 3506e33e..00000000 --- a/test/basics/test_extensions.jl +++ /dev/null @@ -1,2 +0,0 @@ -include("../../ext/BlockSparseArraysTensorAlgebraExt/test/runtests.jl") -include("../../ext/BlockSparseArraysGradedUnitRangesExt/test/runtests.jl") diff --git a/ext/BlockSparseArraysGradedUnitRangesExt/test/runtests.jl b/test/basics/test_gradedunitrangesext.jl similarity index 100% rename from ext/BlockSparseArraysGradedUnitRangesExt/test/runtests.jl rename to test/basics/test_gradedunitrangesext.jl diff --git a/ext/BlockSparseArraysTensorAlgebraExt/test/runtests.jl b/test/basics/test_tensoralgebraext.jl similarity index 100% rename from ext/BlockSparseArraysTensorAlgebraExt/test/runtests.jl rename to test/basics/test_tensoralgebraext.jl