Skip to content

Commit 7ecf6eb

Browse files
committed
Define mixing block sparse array styles
1 parent 7372c2b commit 7ecf6eb

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
1-
using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted
1+
using Base.Broadcast:
2+
Broadcast, BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted
23
using GPUArraysCore: @allowscalar
34
using MapBroadcast: Mapped
45
using DerivableInterfaces: DerivableInterfaces, @interface
56

6-
abstract type AbstractBlockSparseArrayStyle{N,B} <: AbstractArrayStyle{N} end
7+
abstract type AbstractBlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <:
8+
AbstractArrayStyle{N} end
9+
blockstyle(::AbstractBlockSparseArrayStyle{<:Any,B}) where {<:Any,B} = B()
10+
11+
function Broadcast.BroadcastStyle(
12+
style1::AbstractBlockSparseArrayStyle, style2::AbstractBlockSparseArrayStyle
13+
)
14+
return BlockSparseArrayStyle(BroadcastStyle(blockstyle(style1), blockstyle(style2)))
15+
end
716

817
function DerivableInterfaces.interface(
918
::Type{<:AbstractBlockSparseArrayStyle{N,B}}
10-
) where {N,B}
19+
) where {N,B<:AbstractArrayStyle{N}}
1120
return BlockSparseArrayInterface(interface(B))
1221
end
1322

0 commit comments

Comments
 (0)