Skip to content

Commit 88eca40

Browse files
authored
[BlockSparseArrays] Simplifications of blocks for blocksparse Adjoint and Transpose (#1580)
1 parent ea2453c commit 88eca40

File tree

3 files changed

+66
-81
lines changed

3 files changed

+66
-81
lines changed

src/abstractblocksparsearray/views.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,17 @@ function BlockArrays.viewblock(
278278
) where {T,N}
279279
return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...)
280280
end
281+
282+
# migrate wrapper layer for viewing `adjoint` and `transpose`.
283+
for (f, F) in ((:adjoint, :Adjoint), (:transpose, :Transpose))
284+
@eval begin
285+
function Base.view(A::$F{<:Any,<:AbstractBlockSparseVector}, b::Block{1})
286+
return $f(view(parent(A), b))
287+
end
288+
289+
Base.view(A::$F{<:Any,<:AbstractBlockSparseMatrix}, b::Block{2}) = view(A, Tuple(b)...)
290+
function Base.view(A::$F{<:Any,<:AbstractBlockSparseMatrix}, b1::Block{1}, b2::Block{1})
291+
return $f(view(parent(A), b2, b1))
292+
end
293+
end
294+
end

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 2 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -186,79 +186,8 @@ end
186186
reverse_index(index) = reverse(index)
187187
reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index)))
188188

189-
# Represents the array of arrays of a `Transpose`
190-
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Transpose`.
191-
struct SparseTransposeBlocks{T,BlockType<:AbstractArray{T},Array<:Transpose{T}} <:
192-
AbstractSparseMatrix{BlockType}
193-
array::Array
194-
end
195-
function blocksparse_blocks(a::Transpose)
196-
return SparseTransposeBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
197-
end
198-
function Base.size(a::SparseTransposeBlocks)
199-
return reverse(size(blocks(parent(a.array))))
200-
end
201-
function Base.getindex(a::SparseTransposeBlocks, index::Vararg{Int,2})
202-
return transpose(blocks(parent(a.array))[reverse(index)...])
203-
end
204-
# TODO: This should be handled by generic `AbstractSparseArray` code.
205-
function Base.getindex(a::SparseTransposeBlocks, index::CartesianIndex{2})
206-
return a[Tuple(index)...]
207-
end
208-
# TODO: Create a generic `parent_index` function to map an index
209-
# a parent index.
210-
function Base.isassigned(a::SparseTransposeBlocks, index::Vararg{Int,2})
211-
return isassigned(blocks(parent(a.array)), reverse(index)...)
212-
end
213-
function SparseArrayInterface.stored_indices(a::SparseTransposeBlocks)
214-
return map(reverse_index, stored_indices(blocks(parent(a.array))))
215-
end
216-
# TODO: Either make this the generic interface or define
217-
# `SparseArrayInterface.sparse_storage`, which is used
218-
# to defined this.
219-
SparseArrayInterface.stored_length(a::SparseTransposeBlocks) = length(stored_indices(a))
220-
function SparseArrayInterface.sparse_storage(a::SparseTransposeBlocks)
221-
return error("Not implemented")
222-
end
223-
224-
# Represents the array of arrays of a `Adjoint`
225-
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Adjoint`.
226-
struct SparseAdjointBlocks{T,BlockType<:AbstractArray{T},Array<:Adjoint{T}} <:
227-
AbstractSparseMatrix{BlockType}
228-
array::Array
229-
end
230-
function blocksparse_blocks(a::Adjoint)
231-
return SparseAdjointBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
232-
end
233-
function Base.size(a::SparseAdjointBlocks)
234-
return reverse(size(blocks(parent(a.array))))
235-
end
236-
# TODO: Create a generic `parent_index` function to map an index
237-
# a parent index.
238-
function Base.getindex(a::SparseAdjointBlocks, index::Vararg{Int,2})
239-
return blocks(parent(a.array))[reverse(index)...]'
240-
end
241-
# TODO: Create a generic `parent_index` function to map an index
242-
# a parent index.
243-
# TODO: This should be handled by generic `AbstractSparseArray` code.
244-
function Base.getindex(a::SparseAdjointBlocks, index::CartesianIndex{2})
245-
return a[Tuple(index)...]
246-
end
247-
# TODO: Create a generic `parent_index` function to map an index
248-
# a parent index.
249-
function Base.isassigned(a::SparseAdjointBlocks, index::Vararg{Int,2})
250-
return isassigned(blocks(parent(a.array)), reverse(index)...)
251-
end
252-
function SparseArrayInterface.stored_indices(a::SparseAdjointBlocks)
253-
return map(reverse_index, stored_indices(blocks(parent(a.array))))
254-
end
255-
# TODO: Either make this the generic interface or define
256-
# `SparseArrayInterface.sparse_storage`, which is used
257-
# to defined this.
258-
SparseArrayInterface.stored_length(a::SparseAdjointBlocks) = length(stored_indices(a))
259-
function SparseArrayInterface.sparse_storage(a::SparseAdjointBlocks)
260-
return error("Not implemented")
261-
end
189+
blocksparse_blocks(a::Transpose) = transpose(blocks(parent(a)))
190+
blocksparse_blocks(a::Adjoint) = adjoint(blocks(parent(a)))
262191

263192
# Represents the array of arrays of a `SubArray`
264193
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`.

test/test_basics.jl

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using BlockArrays:
1616
mortar
1717
using Compat: @compat
1818
using GPUArraysCore: @allowscalar
19-
using LinearAlgebra: Adjoint, dot, mul!, norm
19+
using LinearAlgebra: Adjoint, Transpose, dot, mul!, norm
2020
using NDTensors.BlockSparseArrays:
2121
@view!,
2222
BlockSparseArray,
@@ -33,7 +33,7 @@ using NDTensors.GPUArraysCoreExtensions: cpu
3333
using NDTensors.SparseArrayInterface: stored_length
3434
using NDTensors.SparseArrayDOKs: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK
3535
using NDTensors.TensorAlgebra: contract
36-
using Test: @test, @test_broken, @test_throws, @testset
36+
using Test: @test, @test_broken, @test_throws, @testset, @inferred
3737
include("TestBlockSparseArraysUtils.jl")
3838

3939
using NDTensors: NDTensors
@@ -70,12 +70,6 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
7070
@test adjoint(a) isa Adjoint{elt,<:BlockSparseArray}
7171
@test_broken adjoint(a)[Block(1), :] isa Adjoint{elt,<:BlockSparseArray}
7272
# could also be directly a BlockSparseArray
73-
74-
a = dev(BlockSparseArray{elt}([1], [1, 1]))
75-
@allowscalar a[1, 2] = 1
76-
@test [a[Block(Tuple(it))] for it in eachindex(block_stored_indices(a))] isa Vector
77-
ah = adjoint(a)
78-
@test_broken [ah[Block(Tuple(it))] for it in eachindex(block_stored_indices(ah))] isa Vector
7973
end
8074
@testset "Constructors" begin
8175
# BlockSparseMatrix
@@ -210,6 +204,54 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
210204
## @test b[Block()[]] == 2
211205
end
212206
end
207+
208+
@testset "Transpose" begin
209+
a = dev(BlockSparseArray{elt}([2, 2], [3, 3, 1]))
210+
a[Block(1, 1)] = dev(randn(elt, 2, 3))
211+
a[Block(2, 3)] = dev(randn(elt, 2, 1))
212+
213+
at = @inferred transpose(a)
214+
@test at isa Transpose
215+
@test size(at) == reverse(size(a))
216+
@test blocksize(at) == reverse(blocksize(a))
217+
@test stored_length(at) == stored_length(a)
218+
@test block_stored_length(at) == block_stored_length(a)
219+
for bind in block_stored_indices(a)
220+
bindt = Block(reverse(Int.(Tuple(bind))))
221+
@test bindt in block_stored_indices(at)
222+
end
223+
224+
@test @views(at[Block(1, 1)]) == transpose(a[Block(1, 1)])
225+
@test @views(at[Block(1, 1)]) isa Transpose
226+
@test @views(at[Block(3, 2)]) == transpose(a[Block(2, 3)])
227+
# TODO: BlockView == AbstractArray calls scalar code
228+
@test @allowscalar @views(at[Block(1, 2)]) == transpose(a[Block(2, 1)])
229+
@test @views(at[Block(1, 2)]) isa Transpose
230+
end
231+
232+
@testset "Adjoint" begin
233+
a = dev(BlockSparseArray{elt}([2, 2], [3, 3, 1]))
234+
a[Block(1, 1)] = dev(randn(elt, 2, 3))
235+
a[Block(2, 3)] = dev(randn(elt, 2, 1))
236+
237+
at = @inferred adjoint(a)
238+
@test at isa Adjoint
239+
@test size(at) == reverse(size(a))
240+
@test blocksize(at) == reverse(blocksize(a))
241+
@test stored_length(at) == stored_length(a)
242+
@test block_stored_length(at) == block_stored_length(a)
243+
for bind in block_stored_indices(a)
244+
bindt = Block(reverse(Int.(Tuple(bind))))
245+
@test bindt in block_stored_indices(at)
246+
end
247+
248+
@test @views(at[Block(1, 1)]) == adjoint(a[Block(1, 1)])
249+
@test @views(at[Block(1, 1)]) isa Adjoint
250+
@test @views(at[Block(3, 2)]) == adjoint(a[Block(2, 3)])
251+
# TODO: BlockView == AbstractArray calls scalar code
252+
@test @allowscalar @views(at[Block(1, 2)]) == adjoint(a[Block(2, 1)])
253+
@test @views(at[Block(1, 2)]) isa Adjoint
254+
end
213255
end
214256
@testset "Tensor algebra" begin
215257
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))

0 commit comments

Comments
 (0)