Skip to content

Commit 621935b

Browse files
authored
Update for TypeParameterAccessors.jl v0.4 (#138)
1 parent c297c03 commit 621935b

File tree

8 files changed

+62
-26
lines changed

8 files changed

+62
-26
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.5"
4+
version = "0.7.6"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -31,7 +31,7 @@ Adapt = "4.1.1"
3131
Aqua = "0.8.9"
3232
ArrayLayouts = "1.10.4"
3333
BlockArrays = "1.2.0"
34-
DerivableInterfaces = "0.5"
34+
DerivableInterfaces = "0.5.2"
3535
DiagonalArrays = "0.3"
3636
Dictionaries = "0.4.3"
3737
FillArrays = "1.13.0"
@@ -44,7 +44,7 @@ SparseArraysBase = "0.5"
4444
SplitApplyCombine = "1.2.3"
4545
TensorAlgebra = "0.3.2"
4646
Test = "1.10"
47-
TypeParameterAccessors = "0.2.0, 0.3"
47+
TypeParameterAccessors = "0.4"
4848
julia = "1.10"
4949

5050
[extras]

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,16 @@ axis(a::AbstractVector) = axes(a, 1)
2929
function eachblockaxis(a::AbstractVector)
3030
return map(axis, blocks(a))
3131
end
32+
function blockaxistype(a::AbstractVector)
33+
return eltype(eachblockaxis(a))
34+
end
3235

3336
# Take a collection of axes and mortar them
3437
# into a single blocked axis.
3538
function mortar_axis(axs)
39+
return blockrange(axs)
40+
end
41+
function mortar_axis(axs::Vector{<:Base.OneTo{<:Integer}})
3642
return blockedrange(length.(axs))
3743
end
3844

src/abstractblocksparsearray/abstractblocksparsearray.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ end
1919

2020
# Specialized in order to fix ambiguity error with `BlockArrays`.
2121
function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N}
22-
return @interface BlockSparseArrayInterface() getindex(a, I...)
22+
return @interface interface(a) getindex(a, I...)
2323
end
2424

2525
# Specialized in order to fix ambiguity error with `BlockArrays`.
2626
function Base.getindex(a::AbstractBlockSparseArray{<:Any,0})
27-
return @interface BlockSparseArrayInterface() getindex(a)
27+
return @interface interface(a) getindex(a)
2828
end
2929

3030
## # Fix ambiguity error with `BlockArrays`.
@@ -39,21 +39,21 @@ end
3939
##
4040
## # Fix ambiguity error with `BlockArrays`.
4141
## function Base.getindex(a::AbstractBlockSparseArray, I::Vararg{AbstractVector})
42-
## ## return @interface BlockSparseArrayInterface() getindex(a, I...)
42+
## ## return @interface interface(a) getindex(a, I...)
4343
## return ArrayLayouts.layout_getindex(a, I...)
4444
## end
4545

4646
# Specialized in order to fix ambiguity error with `BlockArrays`.
4747
function Base.setindex!(
4848
a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Int,N}
4949
) where {N}
50-
@interface BlockSparseArrayInterface() setindex!(a, value, I...)
50+
@interface interface(a) setindex!(a, value, I...)
5151
return a
5252
end
5353

5454
# Fix ambiguity error.
5555
function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value)
56-
@interface BlockSparseArrayInterface() setindex!(a, value)
56+
@interface interface(a) setindex!(a, value)
5757
return a
5858
end
5959

src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ function Base.similar(
2727
elt::Type,
2828
axes,
2929
) where {A,B}
30-
# TODO: Check that this equals `similartype(blocktype(B), elt, axes)`,
31-
# or maybe promote them?
32-
output_blocktype = similartype(blocktype(A), elt, axes)
30+
# TODO: Use something like `Base.promote_op(*, A, B)` to determine the output block type.
31+
output_blocktype = similartype(blocktype(A), Type{elt}, Tuple{blockaxistype.(axes)...})
3332
return similar(BlockSparseArray{elt,length(axes),output_blocktype}, axes)
3433
end
3534

src/abstractblocksparsearray/unblockedsubarray.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,6 @@ function Broadcast.BroadcastStyle(arraytype::Type{<:UnblockedSubArray})
3030
return BroadcastStyle(blocktype(parenttype(arraytype)))
3131
end
3232

33-
function TypeParameterAccessors.similartype(arraytype::Type{<:UnblockedSubArray}, elt::Type)
34-
return similartype(blocktype(parenttype(arraytype)), elt)
35-
end
36-
3733
function Base.similar(
3834
a::UnblockedSubArray, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
3935
)

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ const AnyAbstractBlockSparseVecOrMat{T,N} = Union{
2929
AnyAbstractBlockSparseVector{T},AnyAbstractBlockSparseMatrix{T}
3030
}
3131

32-
function DerivableInterfaces.interface(::Type{<:AnyAbstractBlockSparseArray})
33-
return BlockSparseArrayInterface()
32+
function DerivableInterfaces.interface(arrayt::Type{<:AnyAbstractBlockSparseArray})
33+
return BlockSparseArrayInterface(interface(blocktype(arrayt)))
3434
end
3535

3636
# a[1:2, 1:2]
@@ -231,9 +231,9 @@ function Base.similar(
231231
end
232232

233233
function blocksparse_similar(a, elt::Type, axes::Tuple)
234-
return BlockSparseArray{elt,length(axes),similartype(blocktype(a), elt, axes)}(
235-
undef, axes
236-
)
234+
ndims = length(axes)
235+
blockt = similartype(blocktype(a), Type{elt}, Tuple{blockaxistype.(axes)...})
236+
return BlockSparseArray{elt,ndims,blockt}(undef, axes)
237237
end
238238
@interface ::AbstractBlockSparseArrayInterface function Base.similar(
239239
a::AbstractArray, elt::Type, axes::Tuple{Vararg{Int}}

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@ using BlockArrays:
1616
blocklength,
1717
blocks,
1818
findblockindex
19-
using DerivableInterfaces: DerivableInterfaces, @interface, DefaultArrayInterface, zero!
19+
using DerivableInterfaces:
20+
DerivableInterfaces,
21+
@interface,
22+
AbstractArrayInterface,
23+
DefaultArrayInterface,
24+
interface,
25+
zero!
2026
using LinearAlgebra: Adjoint, Transpose
2127
using SparseArraysBase:
2228
AbstractSparseArrayInterface,
@@ -101,18 +107,47 @@ blockstype(a::BlockArray) = blockstype(typeof(a))
101107
blocktype(arraytype::Type{<:BlockArray}) = eltype(blockstype(arraytype))
102108
blocktype(a::BlockArray) = eltype(blocks(a))
103109

104-
abstract type AbstractBlockSparseArrayInterface{N} <: AbstractSparseArrayInterface{N} end
110+
abstract type AbstractBlockSparseArrayInterface{N,B<:AbstractArrayInterface{N}} <:
111+
AbstractSparseArrayInterface{N} end
112+
113+
function blockinterface(interface::AbstractBlockSparseArrayInterface{<:Any,B}) where {B}
114+
return B()
115+
end
105116

106117
# TODO: Also support specifying the `blocktype` along with the `eltype`.
107-
function Base.similar(::AbstractBlockSparseArrayInterface, T::Type, ax::Tuple)
108-
return similar(BlockSparseArray{T}, ax)
118+
function Base.similar(interface::AbstractBlockSparseArrayInterface, T::Type, ax::Tuple)
119+
# TODO: Generalize by storing the block interface in the block sparse array interface.
120+
N = length(ax)
121+
B = similartype(typeof(blockinterface(interface)), Type{T}, Tuple{blockaxistype.(ax)...})
122+
return similar(BlockSparseArray{T,N,B}, ax)
109123
end
110124

111-
struct BlockSparseArrayInterface{N} <: AbstractBlockSparseArrayInterface{N} end
125+
struct BlockSparseArrayInterface{N,B<:AbstractArrayInterface{N}} <:
126+
AbstractBlockSparseArrayInterface{N,B}
127+
blockinterface::B
128+
end
129+
function BlockSparseArrayInterface{N}(blockinterface::AbstractArrayInterface{N}) where {N}
130+
return BlockSparseArrayInterface{N,typeof(blockinterface)}(blockinterface)
131+
end
132+
function BlockSparseArrayInterface{M,B}(::Val{N}) where {M,B<:AbstractArrayInterface{M},N}
133+
B′ = B(Val(N))
134+
return BlockSparseArrayInterface(B′)
135+
end
136+
function BlockSparseArrayInterface{N}() where {N}
137+
return BlockSparseArrayInterface{N}(DefaultArrayInterface{N}())
138+
end
112139
BlockSparseArrayInterface(::Val{N}) where {N} = BlockSparseArrayInterface{N}()
113140
BlockSparseArrayInterface{M}(::Val{N}) where {M,N} = BlockSparseArrayInterface{N}()
114141
BlockSparseArrayInterface() = BlockSparseArrayInterface{Any}()
115142

143+
function DerivableInterfaces.combine_interface_rule(
144+
interface1::AbstractBlockSparseArrayInterface,
145+
interface2::AbstractBlockSparseArrayInterface,
146+
)
147+
B = interface(blockinterface(interface1), blockinterface(interface2))
148+
return BlockSparseArrayInterface(B)
149+
end
150+
116151
@interface ::AbstractBlockSparseArrayInterface function BlockArrays.blocks(a::AbstractArray)
117152
return error("Not implemented")
118153
end

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,4 @@ Suppressor = "0.2"
3838
TensorAlgebra = "0.3.2"
3939
Test = "1"
4040
TestExtras = "0.3"
41-
TypeParameterAccessors = "0.3"
41+
TypeParameterAccessors = "0.4"

0 commit comments

Comments
 (0)