Skip to content

Commit e128b18

Browse files
authored
Improve showing unstored values (#22)
1 parent 6019b92 commit e128b18

File tree

3 files changed

+73
-1
lines changed

3 files changed

+73
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.3"
4+
version = "0.2.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
99
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1010
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
1111
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
12+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1213
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1415
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -32,6 +33,7 @@ ArrayLayouts = "1.10.4"
3233
BlockArrays = "1.2.0"
3334
DerivableInterfaces = "0.3.7"
3435
Dictionaries = "0.4.3"
36+
FillArrays = "1.13.0"
3537
GPUArraysCore = "0.1.0, 0.2"
3638
GradedUnitRanges = "0.1.0"
3739
LabelledNumbers = "0.1.0"

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,61 @@ end
335335
function Base.Array(a::AnyAbstractBlockSparseArray)
336336
return Array{eltype(a)}(a)
337337
end
338+
339+
using SparseArraysBase: ReplacedUnstoredSparseArray
340+
341+
# Wraps a block sparse array but replaces the unstored values.
342+
# This is used in printing in order to customize printing
343+
# of zero/unstored values.
344+
struct ReplacedUnstoredBlockSparseArray{T,N,F,Parent<:AbstractArray{T,N}} <:
345+
AbstractBlockSparseArray{T,N}
346+
parent::Parent
347+
getunstoredblock::F
348+
end
349+
Base.parent(a::ReplacedUnstoredBlockSparseArray) = a.parent
350+
Base.axes(a::ReplacedUnstoredBlockSparseArray) = axes(parent(a))
351+
function BlockArrays.blocks(a::ReplacedUnstoredBlockSparseArray)
352+
return ReplacedUnstoredSparseArray(blocks(parent(a)), a.getunstoredblock)
353+
end
354+
355+
# This is copied from `SparseArraysBase.jl` since it is not part
356+
# of the public interface.
357+
# Like `Char` but prints without quotes.
358+
struct UnquotedChar <: AbstractChar
359+
char::Char
360+
end
361+
Base.show(io::IO, c::UnquotedChar) = print(io, c.char)
362+
Base.show(io::IO, ::MIME"text/plain", c::UnquotedChar) = show(io, c)
363+
364+
using FillArrays: Fill
365+
struct GetUnstoredBlockShow{Axes}
366+
axes::Axes
367+
end
368+
@inline function (f::GetUnstoredBlockShow)(
369+
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
370+
) where {N}
371+
# TODO: Make sure this works for sparse or block sparse blocks, immutable
372+
# blocks, diagonal blocks, etc.!
373+
b_size = ntuple(ndims(a)) do d
374+
return length(f.axes[d][Block(I[d])])
375+
end
376+
return Fill(UnquotedChar('.'), b_size)
377+
end
378+
# TODO: Use `Base.to_indices`.
379+
@inline function (f::GetUnstoredBlockShow)(
380+
a::AbstractArray{<:Any,N}, I::CartesianIndex{N}
381+
) where {N}
382+
return f(a, Tuple(I)...)
383+
end
384+
385+
# TODO: Make this an `@interface ::AbstractBlockSparseArrayInterface` function
386+
# once we delete the hacky `Base.show` definitions in `BlockSparseArraysTensorAlgebraExt`.
387+
function Base.show(io::IO, mime::MIME"text/plain", a::AnyAbstractBlockSparseArray)
388+
summary(io, a)
389+
isempty(a) && return nothing
390+
print(io, ":")
391+
println(io)
392+
a′ = ReplacedUnstoredBlockSparseArray(a, GetUnstoredBlockShow(axes(a)))
393+
Base.print_array(io, a′)
394+
return nothing
395+
end

test/basics/test_basics.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,4 +1045,16 @@ arrayts = (Array, JLArray)
10451045
@test blockstoredlength(b) == 2
10461046
@test storedlength(b) == 17
10471047
end
1048+
@testset "show" begin
1049+
if elt === Float64
1050+
# Not testing other element types since they change the
1051+
# spacing so it isn't easy to make the test general.
1052+
a = BlockSparseArray{elt}([2, 2], [2, 2])
1053+
a[1, 2] = 12
1054+
# TODO: Reenable this once we delete the hacky `Base.show` definitions
1055+
# in `BlockSparseArraysTensorAlgebraExt`.
1056+
@test_broken sprint(show, "text/plain", a) ==
1057+
"$(summary(a)):$(zero(eltype(a))) $(eltype(a)(12)) │ ⋅ ⋅\n $(zero(eltype(a))) $(zero(eltype(a))) │ ⋅ ⋅\n ───────────┼──────\n ⋅ ⋅ │ ⋅ ⋅\n ⋅ ⋅ │ ⋅ ⋅"
1058+
end
1059+
end
10481060
end

0 commit comments

Comments
 (0)