diff --git a/Project.toml b/Project.toml index 24a9c47c..efe25135 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.2.3" +version = "0.2.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -9,6 +9,7 @@ ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -32,6 +33,7 @@ ArrayLayouts = "1.10.4" BlockArrays = "1.2.0" DerivableInterfaces = "0.3.7" Dictionaries = "0.4.3" +FillArrays = "1.13.0" GPUArraysCore = "0.1.0, 0.2" GradedUnitRanges = "0.1.0" LabelledNumbers = "0.1.0" diff --git a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 2648eddf..bc8c14b9 100644 --- a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -335,3 +335,61 @@ end function Base.Array(a::AnyAbstractBlockSparseArray) return Array{eltype(a)}(a) end + +using SparseArraysBase: ReplacedUnstoredSparseArray + +# Wraps a block sparse array but replaces the unstored values. +# This is used in printing in order to customize printing +# of zero/unstored values. +struct ReplacedUnstoredBlockSparseArray{T,N,F,Parent<:AbstractArray{T,N}} <: + AbstractBlockSparseArray{T,N} + parent::Parent + getunstoredblock::F +end +Base.parent(a::ReplacedUnstoredBlockSparseArray) = a.parent +Base.axes(a::ReplacedUnstoredBlockSparseArray) = axes(parent(a)) +function BlockArrays.blocks(a::ReplacedUnstoredBlockSparseArray) + return ReplacedUnstoredSparseArray(blocks(parent(a)), a.getunstoredblock) +end + +# This is copied from `SparseArraysBase.jl` since it is not part +# of the public interface. +# Like `Char` but prints without quotes. +struct UnquotedChar <: AbstractChar + char::Char +end +Base.show(io::IO, c::UnquotedChar) = print(io, c.char) +Base.show(io::IO, ::MIME"text/plain", c::UnquotedChar) = show(io, c) + +using FillArrays: Fill +struct GetUnstoredBlockShow{Axes} + axes::Axes +end +@inline function (f::GetUnstoredBlockShow)( + a::AbstractArray{<:Any,N}, I::Vararg{Int,N} +) where {N} + # TODO: Make sure this works for sparse or block sparse blocks, immutable + # blocks, diagonal blocks, etc.! + b_size = ntuple(ndims(a)) do d + return length(f.axes[d][Block(I[d])]) + end + return Fill(UnquotedChar('.'), b_size) +end +# TODO: Use `Base.to_indices`. +@inline function (f::GetUnstoredBlockShow)( + a::AbstractArray{<:Any,N}, I::CartesianIndex{N} +) where {N} + return f(a, Tuple(I)...) +end + +# TODO: Make this an `@interface ::AbstractBlockSparseArrayInterface` function +# once we delete the hacky `Base.show` definitions in `BlockSparseArraysTensorAlgebraExt`. +function Base.show(io::IO, mime::MIME"text/plain", a::AnyAbstractBlockSparseArray) + summary(io, a) + isempty(a) && return nothing + print(io, ":") + println(io) + a′ = ReplacedUnstoredBlockSparseArray(a, GetUnstoredBlockShow(axes(a))) + Base.print_array(io, a′) + return nothing +end diff --git a/test/basics/test_basics.jl b/test/basics/test_basics.jl index 6bb31674..7fcc348d 100644 --- a/test/basics/test_basics.jl +++ b/test/basics/test_basics.jl @@ -1045,4 +1045,16 @@ arrayts = (Array, JLArray) @test blockstoredlength(b) == 2 @test storedlength(b) == 17 end + @testset "show" begin + if elt === Float64 + # Not testing other element types since they change the + # spacing so it isn't easy to make the test general. + a = BlockSparseArray{elt}([2, 2], [2, 2]) + a[1, 2] = 12 + # TODO: Reenable this once we delete the hacky `Base.show` definitions + # in `BlockSparseArraysTensorAlgebraExt`. + @test_broken sprint(show, "text/plain", a) == + "$(summary(a)):$(zero(eltype(a))) $(eltype(a)(12)) │ ⋅ ⋅\n $(zero(eltype(a))) $(zero(eltype(a))) │ ⋅ ⋅\n ───────────┼──────\n ⋅ ⋅ │ ⋅ ⋅\n ⋅ ⋅ │ ⋅ ⋅" + end + end end