Skip to content

Commit 7372c2b

Browse files
committed
More general block types in broadcast style
1 parent f991254 commit 7372c2b

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.10"
4+
version = "0.7.11"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractblocksparsearray/broadcast.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using BlockArrays: AbstractBlockedUnitRange, BlockSlice
2-
using Base.Broadcast: Broadcast
2+
using Base.Broadcast: Broadcast, BroadcastStyle
33

44
function Broadcast.BroadcastStyle(arraytype::Type{<:AnyAbstractBlockSparseArray})
5-
return BlockSparseArrayStyle{ndims(arraytype)}()
5+
return BlockSparseArrayStyle(BroadcastStyle(blocktype(arraytype)))
66
end
77

88
# Fix ambiguity error with `BlockArrays`.

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,33 @@ using GPUArraysCore: @allowscalar
33
using MapBroadcast: Mapped
44
using DerivableInterfaces: DerivableInterfaces, @interface
55

6-
abstract type AbstractBlockSparseArrayStyle{N} <: AbstractArrayStyle{N} end
6+
abstract type AbstractBlockSparseArrayStyle{N,B} <: AbstractArrayStyle{N} end
77

8-
function DerivableInterfaces.interface(::Type{<:AbstractBlockSparseArrayStyle})
9-
return BlockSparseArrayInterface()
8+
function DerivableInterfaces.interface(
9+
::Type{<:AbstractBlockSparseArrayStyle{N,B}}
10+
) where {N,B}
11+
return BlockSparseArrayInterface(interface(B))
1012
end
1113

12-
struct BlockSparseArrayStyle{N} <: AbstractBlockSparseArrayStyle{N} end
14+
struct BlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <:
15+
AbstractBlockSparseArrayStyle{N,B}
16+
blockstyle::B
17+
end
18+
function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N}
19+
return BlockSparseArrayStyle{N,typeof(blockstyle)}(blockstyle)
20+
end
1321

1422
# Define for new sparse array types.
1523
# function Broadcast.BroadcastStyle(arraytype::Type{<:MyBlockSparseArray})
1624
# return BlockSparseArrayStyle{ndims(arraytype)}()
1725
# end
1826

27+
BlockSparseArrayStyle{N}() where {N} = BlockSparseArrayStyle{N}(DefaultArrayStyle{N}())
1928
BlockSparseArrayStyle(::Val{N}) where {N} = BlockSparseArrayStyle{N}()
2029
BlockSparseArrayStyle{M}(::Val{N}) where {M,N} = BlockSparseArrayStyle{N}()
30+
function BlockSparseArrayStyle{M,B}(::Val{N}) where {M,B<:AbstractArrayStyle{M},N}
31+
return BlockSparseArrayStyle{N}(B(Val(N)))
32+
end
2133

2234
Broadcast.BroadcastStyle(a::BlockSparseArrayStyle, ::DefaultArrayStyle{0}) = a
2335
function Broadcast.BroadcastStyle(

0 commit comments

Comments
 (0)