Skip to content

Commit abe5237

Browse files
authored
Set up package extensions, improve permutedims (#14)
1 parent d1958cc commit abe5237

File tree

14 files changed

+119
-38
lines changed

14 files changed

+119
-38
lines changed

Project.toml

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,21 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1010
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
1111
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
1212
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
13-
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1413
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
15-
LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
1614
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1715
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
18-
NestedPermutedDimsArrays = "2c2a8ec4-3cfc-4276-aa3e-1307b4294e58"
1916
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
2017
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
21-
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
2218
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
2319

20+
[weakdeps]
21+
LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
22+
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
23+
24+
[extensions]
25+
BlockSparseArraysAdaptExt = "Adapt"
26+
BlockSparseArraysTensorAlgebraExt = ["LabelledNumbers", "TensorAlgebra"]
27+
2428
[compat]
2529
Adapt = "4.1.1"
2630
Aqua = "0.8.9"
@@ -29,16 +33,20 @@ BlockArrays = "1.2.0"
2933
Derive = "0.3.1"
3034
Dictionaries = "0.4.3"
3135
GPUArraysCore = "0.1.0"
36+
GradedUnitRanges = "0.1.0"
37+
LabelledNumbers = "0.1.0"
3238
LinearAlgebra = "1.10"
3339
MacroTools = "0.5.13"
3440
SparseArraysBase = "0.2"
3541
SplitApplyCombine = "1.2.3"
3642
TensorAlgebra = "0.1.0"
43+
TypeParameterAccessors = "0.1.0"
3744
Test = "1.10"
3845
julia = "1.10"
3946

4047
[extras]
4148
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
49+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
4250
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4351

4452
[targets]
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
module BlockSparseArraysAdaptExt
22
using Adapt: Adapt, adapt
3-
using ..BlockSparseArrays: AbstractBlockSparseArray, map_stored_blocks
3+
using BlockSparseArrays: AbstractBlockSparseArray, map_stored_blocks
44
Adapt.adapt_structure(to, x::AbstractBlockSparseArray) = map_stored_blocks(adapt(to), x)
55
end

ext/BlockSparseArraysGradedUnitRangesExt/test/Project.toml

Lines changed: 0 additions & 3 deletions
This file was deleted.

ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl renamed to ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysGradedUnitRangesExt.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
module BlockSparseArraysTensorAlgebraExt
22
using BlockArrays: AbstractBlockedUnitRange
3-
using ..BlockSparseArrays: AbstractBlockSparseArray, blockreshape
4-
using GradedUnitRanges: tensor_product
3+
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
54
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
65

7-
function TensorAlgebra.:(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
8-
return tensor_product(a1, a2)
9-
end
10-
116
TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion()
127

138
function TensorAlgebra.fusedims(

ext/BlockSparseArraysGradedUnitRangesExt/src/BlockSparseArraysGradedUnitRangesExt.jl renamed to ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,38 @@
1-
module BlockSparseArraysGradedUnitRangesExt
1+
module BlockSparseArraysTensorAlgebraExt
2+
using BlockArrays: AbstractBlockedUnitRange
3+
using GradedUnitRanges: tensor_product
4+
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
5+
6+
function TensorAlgebra.:(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
7+
return tensor_product(a1, a2)
8+
end
9+
10+
using BlockArrays: AbstractBlockedUnitRange
11+
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
12+
using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
13+
14+
TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion()
15+
16+
function TensorAlgebra.fusedims(
17+
::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
18+
)
19+
return blockreshape(a, axes)
20+
end
21+
22+
function TensorAlgebra.splitdims(
23+
::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
24+
)
25+
return blockreshape(a, axes)
26+
end
27+
228
using BlockArrays:
329
AbstractBlockVector,
430
AbstractBlockedUnitRange,
531
Block,
632
BlockIndexRange,
733
blockedrange,
834
blocks
9-
using ..BlockSparseArrays:
35+
using BlockSparseArrays:
1036
BlockSparseArrays,
1137
AbstractBlockSparseArray,
1238
AbstractBlockSparseArrayInterface,

ext/BlockSparseArraysTensorAlgebraExt/test/Project.toml

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/BlockSparseArrays.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,4 @@ include("abstractblocksparsearray/cat.jl")
2222
include("blocksparsearray/defaults.jl")
2323
include("blocksparsearray/blocksparsearray.jl")
2424
include("BlockArraysSparseArraysBaseExt/BlockArraysSparseArraysBaseExt.jl")
25-
include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl")
26-
include(
27-
"../ext/BlockSparseArraysGradedUnitRangesExt/src/BlockSparseArraysGradedUnitRangesExt.jl"
28-
)
29-
include("../ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl")
3025
end

src/abstractblocksparsearray/map.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,22 @@ function Base.copyto!(
137137
return @interface interface(a_src) copyto!(a_dest, a_src)
138138
end
139139

140+
# This avoids going through the generic version that calls `Base.permutedims!`,
141+
# which eventually calls block sparse `map!`, which involves slicing operations
142+
# that are not friendly to GPU (since they involve `SubArray` wrapping
143+
# `PermutedDimsArray`).
144+
# TODO: Handle slicing better in `map!` so that this can be removed.
145+
function Base.permutedims(a::AnyAbstractBlockSparseArray, perm)
146+
@interface interface(a) permutedims(a, perm)
147+
end
148+
149+
# The `::AbstractBlockSparseArrayInterface` version
150+
# has a special case for when `a_dest` and `PermutedDimsArray(a_src, perm)`
151+
# have the same blocking, and therefore can just use:
152+
# ```julia
153+
# permutedims!(blocks(a_dest), blocks(a_src), perm)
154+
# ```
155+
# TODO: Handle slicing better in `map!` so that this can be removed.
140156
function Base.permutedims!(a_dest, a_src::AnyAbstractBlockSparseArray, perm)
141157
return @interface interface(a_src) permutedims!(a_dest, a_src, perm)
142158
end

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,17 @@ using BlockArrays:
1414
blocklengths,
1515
blocks,
1616
findblockindex
17-
using Derive: Derive, @interface
17+
using Derive: Derive, @interface, DefaultArrayInterface
1818
using LinearAlgebra: Adjoint, Transpose
1919
using SparseArraysBase:
20-
AbstractSparseArrayInterface, eachstoredindex, perm, iperm, storedlength, storedvalues
20+
AbstractSparseArrayInterface,
21+
getstoredindex,
22+
getunstoredindex,
23+
eachstoredindex,
24+
perm,
25+
iperm,
26+
storedlength,
27+
storedvalues
2128

2229
# Like `SparseArraysBase.eachstoredindex` but
2330
# at the block level, i.e. iterates over the
@@ -154,6 +161,39 @@ end
154161
return a
155162
end
156163

164+
# Version of `permutedims!` that assumes the destination and source
165+
# have the same blocking.
166+
# TODO: Delete this and handle this logic in block sparse `map!`.
167+
function blocksparse_permutedims!(a_dest::AbstractArray, a_src::AbstractArray, perm)
168+
blocks(a_dest) .= blocks(PermutedDimsArray(a_src, perm))
169+
return a_dest
170+
end
171+
172+
# We overload `permutedims` here so that we can assume the destination and source
173+
# have the same blocking and avoid non-GPU friendly slicing operations in block sparse `map!`.
174+
# TODO: Delete this and handle this logic in block sparse `map!`.
175+
@interface ::AbstractBlockSparseArrayInterface function Base.permutedims(
176+
a::AbstractArray, perm
177+
)
178+
a_dest = similar(PermutedDimsArray(a, perm))
179+
blocksparse_permutedims!(a_dest, a, perm)
180+
return a_dest
181+
end
182+
183+
# We overload `permutedims!` here so that we can special case when the destination and source
184+
# have the same blocking and avoid non-GPU friendly slicing operations in block sparse `map!`.
185+
# TODO: Delete this and handle this logic in block sparse `map!`.
186+
@interface ::AbstractBlockSparseArrayInterface function Base.permutedims!(
187+
a_dest::AbstractArray, a_src::AbstractArray, perm
188+
)
189+
if all(blockisequal.(axes(a_dest), axes(PermutedDimsArray(a_src, perm))))
190+
blocksparse_permutedims!(a_dest, a_src, perm)
191+
return a_dest
192+
end
193+
@interface DefaultArrayInterface() permutedims!(a_dest, a_src, perm)
194+
return a_dest
195+
end
196+
157197
@interface ::AbstractBlockSparseArrayInterface function Base.fill!(a::AbstractArray, value)
158198
# TODO: Only do this check if `value isa Number`?
159199
if iszero(value)
@@ -190,6 +230,7 @@ _getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), i
190230

191231
# Represents the array of arrays of a `PermutedDimsArray`
192232
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `PermutedDimsArray`.
233+
# TODO: Delete this in favor of `NestedPermutedDimsArrays.NestedPermutedDimsArray`.
193234
struct SparsePermutedDimsArrayBlocks{
194235
T,N,BlockType<:AbstractArray{T,N},Array<:PermutedDimsArray{T,N}
195236
} <: AbstractSparseArray{BlockType,N}
@@ -203,23 +244,31 @@ end
203244
function Base.size(a::SparsePermutedDimsArrayBlocks)
204245
return _getindices(size(blocks(parent(a.array))), _perm(a.array))
205246
end
206-
function Base.getindex(
247+
function SparseArraysBase.isstored(
248+
a::SparsePermutedDimsArrayBlocks{<:Any,N}, index::Vararg{Int,N}
249+
) where {N}
250+
return isstored(blocks(parent(a.array)), _getindices(index, _invperm(a.array))...)
251+
end
252+
function SparseArraysBase.getstoredindex(
207253
a::SparsePermutedDimsArrayBlocks{<:Any,N}, index::Vararg{Int,N}
208254
) where {N}
209255
return PermutedDimsArray(
210-
blocks(parent(a.array))[_getindices(index, _invperm(a.array))...], _perm(a.array)
256+
getstoredindex(blocks(parent(a.array)), _getindices(index, _invperm(a.array))...),
257+
_perm(a.array),
258+
)
259+
end
260+
function SparseArraysBase.getunstoredindex(
261+
a::SparsePermutedDimsArrayBlocks{<:Any,N}, index::Vararg{Int,N}
262+
) where {N}
263+
return PermutedDimsArray(
264+
getunstoredindex(blocks(parent(a.array)), _getindices(index, _invperm(a.array))...),
265+
_perm(a.array),
211266
)
212267
end
213268
function SparseArraysBase.eachstoredindex(a::SparsePermutedDimsArrayBlocks)
214269
return map(I -> _getindices(I, _perm(a.array)), eachstoredindex(blocks(parent(a.array))))
215270
end
216-
# TODO: Either make this the generic interface or define
217-
# `SparseArraysBase.sparse_storage`, which is used
218-
# to defined this.
219-
function SparseArraysBase.storedlength(a::SparsePermutedDimsArrayBlocks)
220-
return length(eachstoredindex(a))
221-
end
222-
## TODO: Delete.
271+
## TODO: Define `storedvalues` instead.
223272
## function SparseArraysBase.sparse_storage(a::SparsePermutedDimsArrayBlocks)
224273
## return error("Not implemented")
225274
## end

0 commit comments

Comments
 (0)