Skip to content

Commit 2578245

Browse files
authored
Better definitions of norm and tr (#127)
1 parent 01d034f commit 2578245

File tree

5 files changed

+57
-10
lines changed

5 files changed

+57
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.6.6"
4+
version = "0.6.7"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractblocksparsearray/linearalgebra.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearAlgebra: Adjoint, Transpose
1+
using LinearAlgebra: LinearAlgebra, Adjoint, Transpose, norm, tr
22

33
# Like: https://github.com/JuliaLang/julia/blob/v1.11.1/stdlib/LinearAlgebra/src/transpose.jl#L184
44
# but also takes the dual of the axes.
@@ -16,3 +16,19 @@ function Base.copy(a::Transpose{T,<:AbstractBlockSparseMatrix{T}}) where {T}
1616
a_dest .= a
1717
return a_dest
1818
end
19+
20+
function LinearAlgebra.norm(a::AnyAbstractBlockSparseArray, p::Real=2)
21+
nrmᵖ = float(norm(zero(eltype(a))))
22+
for I in eachblockstoredindex(a)
23+
nrmᵖ += norm(@view(a[I]), p)^p
24+
end
25+
return nrmᵖ^(1/p)
26+
end
27+
28+
function LinearAlgebra.tr(a::AnyAbstractBlockSparseMatrix)
29+
tr_a = zero(eltype(a))
30+
for I in eachstoredblockdiagindex(a)
31+
tr_a += tr(@view(a[I]))
32+
end
33+
return tr_a
34+
end

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -338,14 +338,17 @@ end
338338
# TODO: Implement this in a more generic way using a smarter `copyto!`,
339339
# which is ultimately what `Array{T,N}(::AbstractArray{<:Any,N})` calls.
340340
# These are defined for now to avoid scalar indexing issues when there
341-
# are blocks on GPU.
341+
# are blocks on GPU, and also work with exotic block types like
342+
# KroneckerArrays.
342343
function Base.Array{T,N}(a::AnyAbstractBlockSparseArray{<:Any,N}) where {T,N}
343-
# First make it dense, then move to CPU.
344-
# Directly copying to CPU causes some issues with
345-
# scalar indexing on GPU which we have to investigate.
346-
a_dest = similartype(blocktype(a), T)(undef, size(a))
347-
a_dest .= a
348-
return Array{T,N}(a_dest)
344+
a_dest = zeros(T, size(a))
345+
for I in eachblockstoredindex(a)
346+
# TODO: Use: `I′ = CartesianIndices(axes(a))[I]`, unfortunately this
347+
# outputs `Matrix{CartesianIndex}` instead of `CartesianIndices`.
348+
I′ = CartesianIndices(ntuple(dim -> axes(a, dim)[Tuple(I)[dim]], ndims(a)))
349+
a_dest[I′] = Array{T,N}(@view(a[I]))
350+
end
351+
return a_dest
349352
end
350353
function Base.Array{T}(a::AnyAbstractBlockSparseArray) where {T}
351354
return Array{T,ndims(a)}(a)

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ function eachblockstoredindex(a::AbstractArray)
3636
return Block.(Tuple.(eachstoredindex(blocks(a))))
3737
end
3838

39+
using DiagonalArrays: diagindices
40+
# Block version of `DiagonalArrays.diagindices`.
41+
function blockdiagindices(a::AbstractArray)
42+
return Block.(Tuple.(diagindices(blocks(a))))
43+
end
44+
45+
function eachstoredblockdiagindex(a::AbstractArray)
46+
return eachblockstoredindex(a) blockdiagindices(a)
47+
end
48+
3949
# Like `BlockArrays.eachblock` but only iterating
4050
# over stored blocks.
4151
function eachstoredblock(a::AbstractArray)

test/test_basics.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,19 @@ using BlockSparseArrays:
2222
BlockSparseMatrix,
2323
BlockSparseVector,
2424
BlockView,
25+
blockdiagindices,
2526
blockreshape,
2627
blockstoredlength,
2728
blockstype,
2829
blocktype,
2930
eachblockstoredindex,
3031
eachstoredblock,
32+
eachstoredblockdiagindex,
3133
sparsemortar,
3234
view!
3335
using GPUArraysCore: @allowscalar
3436
using JLArrays: JLArray, JLMatrix
35-
using LinearAlgebra: Adjoint, Transpose, dot, norm
37+
using LinearAlgebra: Adjoint, Transpose, dot, norm, tr
3638
using SparseArraysBase: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK, storedlength
3739
using Test: @test, @test_broken, @test_throws, @testset, @inferred
3840
using TestExtras: @constinferred
@@ -217,10 +219,26 @@ arrayts = (Array, JLArray)
217219
a[Block(1, 2)] = randn(elt, 2, 3)
218220
@test issetequal(eachstoredblock(a), [a[Block(2, 1)], a[Block(1, 2)]])
219221
@test issetequal(eachblockstoredindex(a), [Block(2, 1), Block(1, 2)])
222+
@test issetequal(blockdiagindices(a), [Block(1, 1), Block(2, 2)])
223+
@test isempty(eachstoredblockdiagindex(a))
224+
@test norm(a) norm(Array(a))
225+
for p in 1:3
226+
@test norm(a, p) norm(Array(a), p)
227+
end
228+
@test tr(a) tr(Array(a))
220229

221230
a[3, 3] = NaN
222231
@test isnan(norm(a))
223232

233+
a = dev(BlockSparseArray{elt}(undef, [2, 3], [2, 3]))
234+
a[Block(1, 1)] = dev(randn(elt, 2, 2))
235+
@test issetequal(eachstoredblockdiagindex(a), [Block(1, 1)])
236+
@test norm(a) norm(Array(a))
237+
for p in 1:3
238+
@test norm(a, p) norm(Array(a), p)
239+
end
240+
@test tr(a) tr(Array(a))
241+
224242
# Empty constructor
225243
for a in (dev(BlockSparseArray{elt}(undef)),)
226244
@test size(a) == ()

0 commit comments

Comments
 (0)