@@ -3,21 +3,33 @@ using GPUArraysCore: @allowscalar
3
3
using MapBroadcast: Mapped
4
4
using DerivableInterfaces: DerivableInterfaces, @interface
5
5
6
- abstract type AbstractBlockSparseArrayStyle{N} <: AbstractArrayStyle{N} end
6
+ abstract type AbstractBlockSparseArrayStyle{N,B } <: AbstractArrayStyle{N} end
7
7
8
- function DerivableInterfaces. interface (:: Type{<:AbstractBlockSparseArrayStyle} )
9
- return BlockSparseArrayInterface ()
8
+ function DerivableInterfaces. interface (
9
+ :: Type{<:AbstractBlockSparseArrayStyle{N,B}}
10
+ ) where {N,B}
11
+ return BlockSparseArrayInterface (interface (B))
10
12
end
11
13
12
- struct BlockSparseArrayStyle{N} <: AbstractBlockSparseArrayStyle{N} end
14
+ struct BlockSparseArrayStyle{N,B<: AbstractArrayStyle{N} } < :
15
+ AbstractBlockSparseArrayStyle{N,B}
16
+ blockstyle:: B
17
+ end
18
+ function BlockSparseArrayStyle {N} (blockstyle:: AbstractArrayStyle{N} ) where {N}
19
+ return BlockSparseArrayStyle {N,typeof(blockstyle)} (blockstyle)
20
+ end
13
21
14
22
# Define for new sparse array types.
15
23
# function Broadcast.BroadcastStyle(arraytype::Type{<:MyBlockSparseArray})
16
24
# return BlockSparseArrayStyle{ndims(arraytype)}()
17
25
# end
18
26
27
+ BlockSparseArrayStyle {N} () where {N} = BlockSparseArrayStyle {N} (DefaultArrayStyle {N} ())
19
28
BlockSparseArrayStyle (:: Val{N} ) where {N} = BlockSparseArrayStyle {N} ()
20
29
BlockSparseArrayStyle {M} (:: Val{N} ) where {M,N} = BlockSparseArrayStyle {N} ()
30
+ function BlockSparseArrayStyle {M,B} (:: Val{N} ) where {M,B<: AbstractArrayStyle{M} ,N}
31
+ return BlockSparseArrayStyle {N} (B (Val (N)))
32
+ end
21
33
22
34
Broadcast. BroadcastStyle (a:: BlockSparseArrayStyle , :: DefaultArrayStyle{0} ) = a
23
35
function Broadcast. BroadcastStyle (
0 commit comments