Skip to content

Commit 8656a56

Browse files
committed
[WIP] [BlockSparseArrays] Define more constructors
1 parent 2e5da07 commit 8656a56

File tree

3 files changed

+112
-2
lines changed

3 files changed

+112
-2
lines changed

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,15 @@ struct BlockSparseArray{
1616
axes::Axes
1717
end
1818

19-
const BlockSparseMatrix{T,A,Blocks,Axes} = BlockSparseArray{T,2,A,Blocks,Axes}
20-
const BlockSparseVector{T,A,Blocks,Axes} = BlockSparseArray{T,1,A,Blocks,Axes}
19+
# TODO: Can this definition be shortened?
20+
const BlockSparseMatrix{T,A<:AbstractMatrix{T},Blocks<:AbstractMatrix{A},Axes<:Tuple{AbstractUnitRange,AbstractUnitRange}} = BlockSparseArray{
21+
T,2,A,Blocks,Axes
22+
}
23+
24+
# TODO: Can this definition be shortened?
25+
const BlockSparseVector{T,A<:AbstractVector{T},Blocks<:AbstractVector{A},Axes<:Tuple{AbstractUnitRange}} = BlockSparseArray{
26+
T,1,A,Blocks,Axes
27+
}
2128

2229
function BlockSparseArray(
2330
block_data::Dictionary{<:Block{N},<:AbstractArray{<:Any,N}},
@@ -72,6 +79,10 @@ function BlockSparseArray{T,N}(axes::Tuple{Vararg{AbstractUnitRange,N}}) where {
7279
return BlockSparseArray{T,N,default_arraytype(T, axes)}(axes)
7380
end
7481

82+
function BlockSparseArray{T,N}(axes::Vararg{AbstractUnitRange,N}) where {T,N}
83+
return BlockSparseArray{T,N}(axes)
84+
end
85+
7586
function BlockSparseArray{T,0}(axes::Tuple{}) where {T}
7687
return BlockSparseArray{T,0,default_arraytype(T, axes)}(axes)
7788
end
@@ -80,6 +91,10 @@ function BlockSparseArray{T,N}(dims::Tuple{Vararg{Vector{Int},N}}) where {T,N}
8091
return BlockSparseArray{T,N}(blockedrange.(dims))
8192
end
8293

94+
function BlockSparseArray{T,N}(dims::Vararg{Vector{Int},N}) where {T,N}
95+
return BlockSparseArray{T,N}(dims)
96+
end
97+
8398
function BlockSparseArray{T}(dims::Tuple{Vararg{Vector{Int}}}) where {T}
8499
return BlockSparseArray{T,length(dims)}(dims)
85100
end
@@ -117,18 +132,37 @@ function BlockSparseArray{T,N}(
117132
return BlockSparseArray{T,N}(axes)
118133
end
119134

135+
function BlockSparseArray{T,N}(
136+
::UndefInitializer, axes::Vararg{AbstractUnitRange,N}
137+
) where {T,N}
138+
return BlockSparseArray{T,N}(axes)
139+
end
140+
120141
function BlockSparseArray{T,N}(
121142
::UndefInitializer, dims::Tuple{Vararg{Vector{Int},N}}
122143
) where {T,N}
123144
return BlockSparseArray{T,N}(dims)
124145
end
125146

147+
function BlockSparseArray{T,N}(::UndefInitializer, dims::Vararg{Vector{Int},N}) where {T,N}
148+
return BlockSparseArray{T,N}(dims)
149+
end
150+
126151
function BlockSparseArray{T}(
127152
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange}}
128153
) where {T}
129154
return BlockSparseArray{T}(axes)
130155
end
131156

157+
function BlockSparseArray{T}(::UndefInitializer, axes::Vararg{AbstractUnitRange}) where {T}
158+
return BlockSparseArray{T}(axes...)
159+
end
160+
161+
# Fix ambiguity error.
162+
function BlockSparseArray{T}(::UndefInitializer) where {T}
163+
return BlockSparseArray{T}()
164+
end
165+
132166
function BlockSparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Vector{Int}}}) where {T}
133167
return BlockSparseArray{T}(dims)
134168
end

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ using ..SparseArrayInterface: perm, iperm, nstored, sparse_zero!
1717

1818
blocksparse_blocks(a::AbstractArray) = error("Not implemented")
1919

20+
blockstype(a::AbstractArray) = blockstype(typeof(a))
21+
2022
function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N}
2123
@boundscheck checkbounds(a, I...)
2224
return a[findblockindex.(axes(a), I)...]

NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,18 @@ using LinearAlgebra: Adjoint, dot, mul!, norm
2020
using NDTensors.BlockSparseArrays:
2121
@view!,
2222
BlockSparseArray,
23+
BlockSparseMatrix,
24+
BlockSparseVector,
2325
BlockView,
2426
block_nstored,
2527
block_reshape,
2628
block_stored_indices,
29+
blockstype,
30+
blocktype,
2731
view!
2832
using NDTensors.GPUArraysCoreExtensions: cpu
2933
using NDTensors.SparseArrayInterface: nstored
34+
using NDTensors.SparseArrayDOKs: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK
3035
using NDTensors.TensorAlgebra: contract
3136
using Test: @test, @test_broken, @test_throws, @testset
3237
include("TestBlockSparseArraysUtils.jl")
@@ -72,6 +77,75 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
7277
ah = adjoint(a)
7378
@test_broken [ah[Block(Tuple(it))] for it in eachindex(block_stored_indices(ah))] isa Vector
7479
end
80+
@testset "Constructors" begin
81+
# BlockSparseMatrix
82+
bs = ([2, 3], [3, 4])
83+
for T in (
84+
BlockSparseArray{elt},
85+
BlockSparseArray{elt,2},
86+
BlockSparseMatrix{elt},
87+
## BlockSparseArray{elt,2,Matrix{elt}},
88+
## BlockSparseMatrix{elt,Matrix{elt}},
89+
## BlockSparseArray{elt,2,Matrix{elt},SparseMatrixDOK{Matrix{elt}}},
90+
## BlockSparseMatrix{elt,Matrix{elt},SparseMatrixDOK{Matrix{elt}}},
91+
)
92+
for args in (
93+
bs,
94+
(bs,),
95+
blockedrange.(bs),
96+
(blockedrange.(bs),),
97+
(undef, bs),
98+
(undef, bs...),
99+
(undef, blockedrange.(bs)),
100+
(undef, blockedrange.(bs)...),
101+
)
102+
a = T(args...)
103+
@test eltype(a) == elt
104+
@test blocktype(a) == Matrix{elt}
105+
@test blockstype(a) <: SparseMatrixDOK{Matrix{elt}}
106+
@test blocklengths.(axes(a)) == ([2, 3], [3, 4])
107+
@test iszero(a)
108+
@test_broken iszero(block_stored_length(a))
109+
@test iszero(block_nstored(a))
110+
@test_broken iszero(stored_length(a))
111+
@test iszero(nstored(a))
112+
end
113+
end
114+
115+
# BlockSparseVector
116+
bs = ([2, 3],)
117+
for T in (
118+
BlockSparseArray{elt},
119+
BlockSparseArray{elt,1},
120+
BlockSparseVector{elt},
121+
## BlockSparseArray{elt,1,Vector{elt}},
122+
## BlockSparseVector{elt,Vector{elt}},
123+
## BlockSparseArray{elt,1,Vector{elt},SparseVectorDOK{Vector{elt}}},
124+
## BlockSparseVector{elt,Vector{elt},SparseVectorDOK{Vector{elt}}},
125+
)
126+
for args in (
127+
bs,
128+
(bs,),
129+
blockedrange.(bs),
130+
(blockedrange.(bs),),
131+
(undef, bs),
132+
(undef, bs...),
133+
(undef, blockedrange.(bs)),
134+
(undef, blockedrange.(bs)...),
135+
)
136+
a = T(args...)
137+
@test eltype(a) == elt
138+
@test blocktype(a) == Vector{elt}
139+
@test blockstype(a) <: SparseVectorDOK{Vector{elt}}
140+
@test blocklengths.(axes(a)) == ([2, 3],)
141+
@test iszero(a)
142+
@test_broken iszero(block_stored_length(a))
143+
@test iszero(block_nstored(a))
144+
@test_broken iszero(stored_length(a))
145+
@test iszero(nstored(a))
146+
end
147+
end
148+
end
75149
@testset "Basics" begin
76150
a = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
77151
@allowscalar @test a == dev(

0 commit comments

Comments
 (0)