Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,17 @@ function BlockArrays.viewblock(
) where {T,N}
return view(viewblock(a, Block.(block)...), map(b -> only(b.indices), block)...)
end

# migrate wrapper layer for viewing `adjoint` and `transpose`.
for (f, F) in ((:adjoint, :Adjoint), (:transpose, :Transpose))
@eval begin
function Base.view(A::$F{<:Any,<:AbstractBlockSparseVector}, b::Block{1})
return $f(view(parent(A), b))
end

Base.view(A::$F{<:Any,<:AbstractBlockSparseMatrix}, b::Block{2}) = view(A, Tuple(b)...)
function Base.view(A::$F{<:Any,<:AbstractBlockSparseMatrix}, b1::Block{1}, b2::Block{1})
return $f(view(parent(A), b2, b1))
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -186,79 +186,8 @@ end
reverse_index(index) = reverse(index)
reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index)))

# Represents the array of arrays of a `Transpose`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Transpose`.
struct SparseTransposeBlocks{T,BlockType<:AbstractArray{T},Array<:Transpose{T}} <:
AbstractSparseMatrix{BlockType}
array::Array
end
function blocksparse_blocks(a::Transpose)
return SparseTransposeBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
end
function Base.size(a::SparseTransposeBlocks)
return reverse(size(blocks(parent(a.array))))
end
function Base.getindex(a::SparseTransposeBlocks, index::Vararg{Int,2})
return transpose(blocks(parent(a.array))[reverse(index)...])
end
# TODO: This should be handled by generic `AbstractSparseArray` code.
function Base.getindex(a::SparseTransposeBlocks, index::CartesianIndex{2})
return a[Tuple(index)...]
end
# TODO: Create a generic `parent_index` function to map an index
# a parent index.
function Base.isassigned(a::SparseTransposeBlocks, index::Vararg{Int,2})
return isassigned(blocks(parent(a.array)), reverse(index)...)
end
function SparseArrayInterface.stored_indices(a::SparseTransposeBlocks)
return map(reverse_index, stored_indices(blocks(parent(a.array))))
end
# TODO: Either make this the generic interface or define
# `SparseArrayInterface.sparse_storage`, which is used
# to defined this.
SparseArrayInterface.stored_length(a::SparseTransposeBlocks) = length(stored_indices(a))
function SparseArrayInterface.sparse_storage(a::SparseTransposeBlocks)
return error("Not implemented")
end

# Represents the array of arrays of a `Adjoint`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Adjoint`.
struct SparseAdjointBlocks{T,BlockType<:AbstractArray{T},Array<:Adjoint{T}} <:
AbstractSparseMatrix{BlockType}
array::Array
end
function blocksparse_blocks(a::Adjoint)
return SparseAdjointBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
end
function Base.size(a::SparseAdjointBlocks)
return reverse(size(blocks(parent(a.array))))
end
# TODO: Create a generic `parent_index` function to map an index
# a parent index.
function Base.getindex(a::SparseAdjointBlocks, index::Vararg{Int,2})
return blocks(parent(a.array))[reverse(index)...]'
end
# TODO: Create a generic `parent_index` function to map an index
# a parent index.
# TODO: This should be handled by generic `AbstractSparseArray` code.
function Base.getindex(a::SparseAdjointBlocks, index::CartesianIndex{2})
return a[Tuple(index)...]
end
# TODO: Create a generic `parent_index` function to map an index
# a parent index.
function Base.isassigned(a::SparseAdjointBlocks, index::Vararg{Int,2})
return isassigned(blocks(parent(a.array)), reverse(index)...)
end
function SparseArrayInterface.stored_indices(a::SparseAdjointBlocks)
return map(reverse_index, stored_indices(blocks(parent(a.array))))
end
# TODO: Either make this the generic interface or define
# `SparseArrayInterface.sparse_storage`, which is used
# to defined this.
SparseArrayInterface.stored_length(a::SparseAdjointBlocks) = length(stored_indices(a))
function SparseArrayInterface.sparse_storage(a::SparseAdjointBlocks)
return error("Not implemented")
end
blocksparse_blocks(a::Transpose) = transpose(blocks(parent(a)))
blocksparse_blocks(a::Adjoint) = adjoint(blocks(parent(a)))

# Represents the array of arrays of a `SubArray`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`.
Expand Down
15 changes: 9 additions & 6 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,6 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
@test adjoint(a) isa Adjoint{elt,<:BlockSparseArray}
@test_broken adjoint(a)[Block(1), :] isa Adjoint{elt,<:BlockSparseArray}
# could also be directly a BlockSparseArray

a = dev(BlockSparseArray{elt}([1], [1, 1]))
@allowscalar a[1, 2] = 1
@test [a[Block(Tuple(it))] for it in eachindex(block_stored_indices(a))] isa Vector
ah = adjoint(a)
@test_broken [ah[Block(Tuple(it))] for it in eachindex(block_stored_indices(ah))] isa Vector
end
@testset "Constructors" begin
# BlockSparseMatrix
Expand Down Expand Up @@ -210,6 +204,15 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
## @test b[Block()[]] == 2
end
end

# adjoint and transpose
a = dev(BlockSparseArray{elt}([1], [1, 1]))
@allowscalar a[1, 2] = 1
@test [@views(a[it]) for it in block_stored_indices(a)] isa Vector
ah = adjoint(a)
@test [@views(ah[it]) for it in block_stored_indices(ah)] isa Vector
at = transpose(a)
@test [@views(at[it]) for it in block_stored_indices(at)] isa Vector
end
@testset "Tensor algebra" begin
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Adapt: WrappedArray
using LinearAlgebra: Adjoint, Transpose

const WrappedAbstractSparseArray{T,N,A} = WrappedArray{
T,N,<:AbstractSparseArray,<:AbstractSparseArray{T,N}
Expand All @@ -7,3 +8,13 @@ const WrappedAbstractSparseArray{T,N,A} = WrappedArray{
const AnyAbstractSparseArray{T,N} = Union{
<:AbstractSparseArray{T,N},<:WrappedAbstractSparseArray{T,N}
}

function stored_indices(a::Adjoint)
return Iterators.map(I -> CartesianIndex(reverse(Tuple(I))), stored_indices(parent(a)))
end
stored_length(a::Adjoint) = stored_length(parent(a))

function stored_indices(a::Transpose)
return Iterators.map(I -> CartesianIndex(reverse(Tuple(I))), stored_indices(parent(a)))
end
stored_length(a::Transpose) = stored_length(parent(a))
Loading