Skip to content

Commit 259c763

Browse files
committed
Define combining block sparse interfaces
1 parent 83a6108 commit 259c763

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ const AnyAbstractBlockSparseVecOrMat{T,N} = Union{
2929
AnyAbstractBlockSparseVector{T},AnyAbstractBlockSparseMatrix{T}
3030
}
3131

32-
function DerivableInterfaces.interface(::Type{<:AnyAbstractBlockSparseArray})
33-
return BlockSparseArrayInterface()
32+
function DerivableInterfaces.interface(arrayt::Type{<:AnyAbstractBlockSparseArray})
33+
return BlockSparseArrayInterface(interface(blocktype(arrayt)))
3434
end
3535

3636
# a[1:2, 1:2]

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@ using BlockArrays:
1717
blocks,
1818
findblockindex
1919
using DerivableInterfaces:
20-
DerivableInterfaces, @interface, AbstractArrayInterface, DefaultArrayInterface, zero!
20+
DerivableInterfaces,
21+
@interface,
22+
AbstractArrayInterface,
23+
DefaultArrayInterface,
24+
interface,
25+
zero!
2126
using LinearAlgebra: Adjoint, Transpose
2227
using SparseArraysBase:
2328
AbstractSparseArrayInterface,
@@ -125,7 +130,8 @@ function BlockSparseArrayInterface{N}(blockinterface::AbstractArrayInterface{N})
125130
return BlockSparseArrayInterface{N,typeof(blockinterface)}(blockinterface)
126131
end
127132
function BlockSparseArrayInterface{M,B}(::Val{N}) where {M,B<:AbstractArrayInterface{M},N}
128-
return BlockSparseArrayInterface{N}(B(Val(N)))
133+
B′ = B(Val(N))
134+
return BlockSparseArrayInterface(B′)
129135
end
130136
function BlockSparseArrayInterface{N}() where {N}
131137
return BlockSparseArrayInterface{N}(DefaultArrayInterface{N}())
@@ -134,6 +140,14 @@ BlockSparseArrayInterface(::Val{N}) where {N} = BlockSparseArrayInterface{N}()
134140
BlockSparseArrayInterface{M}(::Val{N}) where {M,N} = BlockSparseArrayInterface{N}()
135141
BlockSparseArrayInterface() = BlockSparseArrayInterface{Any}()
136142

143+
function DerivableInterfaces.combine_interface_rule(
144+
interface1::AbstractBlockSparseArrayInterface,
145+
interface2::AbstractBlockSparseArrayInterface,
146+
)
147+
B = interface(blockinterface(interface1), blockinterface(interface2))
148+
return BlockSparseArrayInterface(B)
149+
end
150+
137151
@interface ::AbstractBlockSparseArrayInterface function BlockArrays.blocks(a::AbstractArray)
138152
return error("Not implemented")
139153
end

0 commit comments

Comments
 (0)