Skip to content

Commit fe9c657

Browse files
committed
More generalizations
1 parent addaf0b commit fe9c657

File tree

7 files changed

+48
-31
lines changed

7 files changed

+48
-31
lines changed

src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ end
1919

2020
# Specialized in order to fix ambiguity error with `BlockArrays`.
2121
function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N}
22-
return @interface BlockSparseArrayInterface() getindex(a, I...)
22+
return @interface interface(a) getindex(a, I...)
2323
end
2424

2525
# Specialized in order to fix ambiguity error with `BlockArrays`.
2626
function Base.getindex(a::AbstractBlockSparseArray{<:Any,0})
27-
return @interface BlockSparseArrayInterface() getindex(a)
27+
return @interface interface(a) getindex(a)
2828
end
2929

3030
## # Fix ambiguity error with `BlockArrays`.
@@ -53,7 +53,7 @@ end
5353

5454
# Fix ambiguity error.
5555
function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value)
56-
@interface BlockSparseArrayInterface() setindex!(a, value)
56+
@interface interface(a) setindex!(a, value)
5757
return a
5858
end
5959

src/abstractblocksparsearray/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using BlockArrays: AbstractBlockedUnitRange, BlockSlice
22
using Base.Broadcast: Broadcast
33

44
function Broadcast.BroadcastStyle(arraytype::Type{<:AnyAbstractBlockSparseArray})
5-
return BlockSparseArrayStyle{ndims(arraytype)}()
5+
return BlockSparseArrayStyle(BroadcastStyle(blocktype(arraytype)))
66
end
77

88
# Fix ambiguity error with `BlockArrays`.

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ const AnyAbstractBlockSparseVecOrMat{T,N} = Union{
2929
AnyAbstractBlockSparseVector{T},AnyAbstractBlockSparseMatrix{T}
3030
}
3131

32-
function DerivableInterfaces.interface(::Type{<:AnyAbstractBlockSparseArray})
33-
return BlockSparseArrayInterface()
32+
function DerivableInterfaces.interface(arrayt::Type{<:AnyAbstractBlockSparseArray})
33+
return BlockSparseArrayInterface(interface(blocktype(arrayt)))
3434
end
3535

3636
# a[1:2, 1:2]
@@ -88,7 +88,7 @@ end
8888

8989
# BlockArrays `AbstractBlockArray` interface
9090
function BlockArrays.blocks(a::AnyAbstractBlockSparseArray)
91-
@interface BlockSparseArrayInterface() blocks(a)
91+
@interface interface(a) blocks(a)
9292
end
9393

9494
# Fix ambiguity error with `BlockArrays`
@@ -284,7 +284,7 @@ function Base.similar(
284284
elt::Type,
285285
axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}},
286286
)
287-
return @interface BlockSparseArrayInterface() similar(arraytype, elt, axes)
287+
return @interface interface(arraytype) similar(arraytype, elt, axes)
288288
end
289289

290290
# TODO: Define a `@interface BlockSparseArrayInterface() similar` function.
@@ -311,8 +311,7 @@ function Base.similar(
311311
AbstractBlockedUnitRange{<:Integer},Vararg{AbstractBlockedUnitRange{<:Integer}}
312312
},
313313
)
314-
# TODO: Use `@interface interface(a) similar(...)`.
315-
return @interface BlockSparseArrayInterface() similar(a, elt, axes)
314+
return @interface interface(a) similar(a, elt, axes)
316315
end
317316

318317
# Fixes ambiguity error with `OffsetArrays`.
@@ -321,8 +320,7 @@ function Base.similar(
321320
elt::Type,
322321
axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
323322
)
324-
# TODO: Use `@interface interface(a) similar(...)`.
325-
return @interface BlockSparseArrayInterface() similar(a, elt, axes)
323+
return @interface interface(a) similar(a, elt, axes)
326324
end
327325

328326
# Fixes ambiguity error with `BlockArrays`.
@@ -331,8 +329,7 @@ function Base.similar(
331329
elt::Type,
332330
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
333331
)
334-
# TODO: Use `@interface interface(a) similar(...)`.
335-
return @interface BlockSparseArrayInterface() similar(a, elt, axes)
332+
return @interface interface(a) similar(a, elt, axes)
336333
end
337334

338335
# Fixes ambiguity errors with BlockArrays.
@@ -345,16 +342,14 @@ function Base.similar(
345342
Vararg{AbstractUnitRange{<:Integer}},
346343
},
347344
)
348-
# TODO: Use `@interface interface(a) similar(...)`.
349-
return @interface BlockSparseArrayInterface() similar(a, elt, axes)
345+
return @interface interface(a) similar(a, elt, axes)
350346
end
351347

352348
# Fixes ambiguity error with `StaticArrays`.
353349
function Base.similar(
354350
a::AnyAbstractBlockSparseArray, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
355351
)
356-
# TODO: Use `@interface interface(a) similar(...)`.
357-
return @interface BlockSparseArrayInterface() similar(a, elt, axes)
352+
return @interface interface(a) similar(a, elt, axes)
358353
end
359354

360355
struct BlockType{T} end

src/blocksparsearray/blocksparsearray.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,10 @@ Base.axes(a::BlockSparseArray) = a.axes
254254
@interface ::AbstractBlockSparseArrayInterface BlockArrays.blocks(a::BlockSparseArray) =
255255
a.blocks
256256

257+
function blocktype(arraytype::Type{<:BlockSparseArray{<:Any,<:Any,A}}) where {A}
258+
return A
259+
end
260+
257261
# TODO: Use `TypeParameterAccessors`.
258262
function blockstype(
259263
arraytype::Type{<:BlockSparseArray{T,N,A,Blocks}}

src/blocksparsearrayinterface/arraylayouts.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using ArrayLayouts: ArrayLayouts, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd
22
using BlockArrays: BlockArrays, BlockLayout, muladd!
3-
using DerivableInterfaces: @interface
3+
using DerivableInterfaces: DerivableInterfaces, @interface, interface
44
using SparseArraysBase: SparseLayout
55
using LinearAlgebra: LinearAlgebra, dot, mul!
66

@@ -11,15 +11,18 @@ using LinearAlgebra: LinearAlgebra, dot, mul!
1111
return a_dest
1212
end
1313

14+
function DerivableInterfaces.interface(m::MulAdd)
15+
return interface(m.A, m.B, m.C)
16+
end
17+
1418
function ArrayLayouts.materialize!(
1519
m::MatMulMatAdd{
1620
<:BlockLayout{<:SparseLayout},
1721
<:BlockLayout{<:SparseLayout},
1822
<:BlockLayout{<:SparseLayout},
1923
},
2024
)
21-
α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C
22-
@interface BlockSparseArrayInterface() muladd!(m.α, m.A, m.B, m.β, m.C)
25+
@interface interface(m) muladd!(m.α, m.A, m.B, m.β, m.C)
2326
return m.C
2427
end
2528
function ArrayLayouts.materialize!(
@@ -29,7 +32,7 @@ function ArrayLayouts.materialize!(
2932
<:BlockLayout{<:SparseLayout},
3033
},
3134
)
32-
@interface BlockSparseArrayInterface() matmul!(m)
35+
@interface interface(m) matmul!(m)
3336
return m.C
3437
end
3538

@@ -42,5 +45,5 @@ end
4245
end
4346

4447
function Base.copy(d::Dot{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout}})
45-
return @interface BlockSparseArrayInterface() dot(d.A, d.B)
48+
return @interface interface(d.A, d.B) dot(d.A, d.B)
4649
end

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ using BlockArrays:
1616
blocklength,
1717
blocks,
1818
findblockindex
19-
using DerivableInterfaces: DerivableInterfaces, @interface, DefaultArrayInterface, zero!
19+
using DerivableInterfaces:
20+
DerivableInterfaces, @interface, AbstractArrayInterface, DefaultArrayInterface, zero!
2021
using LinearAlgebra: Adjoint, Transpose
2122
using SparseArraysBase:
2223
AbstractSparseArrayInterface,
@@ -104,11 +105,17 @@ blocktype(a::BlockArray) = eltype(blocks(a))
104105
abstract type AbstractBlockSparseArrayInterface <: AbstractSparseArrayInterface end
105106

106107
# TODO: Also support specifying the `blocktype` along with the `eltype`.
107-
function DerivableInterfaces.arraytype(::AbstractBlockSparseArrayInterface, T::Type)
108-
return BlockSparseArray{T}
108+
function DerivableInterfaces.arraytype(
109+
interface::AbstractBlockSparseArrayInterface, T::Type
110+
)
111+
B = DerivableInterfaces.arraytype(interface.blockinterface, T)
112+
return BlockSparseArray{T,<:Any,B}
109113
end
110114

111-
struct BlockSparseArrayInterface <: AbstractBlockSparseArrayInterface end
115+
struct BlockSparseArrayInterface{B<:AbstractArrayInterface} <:
116+
AbstractBlockSparseArrayInterface
117+
blockinterface::B
118+
end
112119

113120
@interface ::AbstractBlockSparseArrayInterface BlockArrays.blocks(a::AbstractArray) = error(
114121
"Not implemented"

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,21 @@ using GPUArraysCore: @allowscalar
33
using MapBroadcast: Mapped
44
using DerivableInterfaces: DerivableInterfaces, @interface
55

6-
abstract type AbstractBlockSparseArrayStyle{N} <: AbstractArrayStyle{N} end
6+
abstract type AbstractBlockSparseArrayStyle{N,B} <: AbstractArrayStyle{N} end
77

8-
function DerivableInterfaces.interface(::Type{<:AbstractBlockSparseArrayStyle})
9-
return BlockSparseArrayInterface()
8+
function DerivableInterfaces.interface(
9+
::Type{<:AbstractBlockSparseArrayStyle{N,B}}
10+
) where {N,B}
11+
return BlockSparseArrayInterface(interface(B))
1012
end
1113

12-
struct BlockSparseArrayStyle{N} <: AbstractBlockSparseArrayStyle{N} end
14+
struct BlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <:
15+
AbstractBlockSparseArrayStyle{N,B}
16+
blockstyle::B
17+
end
18+
function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N}
19+
return BlockSparseArrayStyle{N,typeof(blockstyle)}(blockstyle)
20+
end
1321

1422
# Define for new sparse array types.
1523
# function Broadcast.BroadcastStyle(arraytype::Type{<:MyBlockSparseArray})

0 commit comments

Comments
 (0)