Skip to content

Simplify show implementation #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 23 additions & 54 deletions src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Adapt: Adapt, WrappedArray
using Adapt: Adapt, WrappedArray, adapt
using ArrayLayouts: zero!
using BlockArrays:
BlockArrays,
Expand Down Expand Up @@ -337,60 +337,29 @@ function Base.Array(a::AnyAbstractBlockSparseArray)
return Array{eltype(a)}(a)
end

using SparseArraysBase: ReplacedUnstoredSparseArray

# Wraps a block sparse array but replaces the unstored values.
# This is used in printing in order to customize printing
# of zero/unstored values.
struct ReplacedUnstoredBlockSparseArray{T,N,F,Parent<:AbstractArray{T,N}} <:
AbstractBlockSparseArray{T,N}
parent::Parent
getunstoredblock::F
end
Base.parent(a::ReplacedUnstoredBlockSparseArray) = a.parent
Base.axes(a::ReplacedUnstoredBlockSparseArray) = axes(parent(a))
function BlockArrays.blocks(a::ReplacedUnstoredBlockSparseArray)
return ReplacedUnstoredSparseArray(blocks(parent(a)), a.getunstoredblock)
end

# This is copied from `SparseArraysBase.jl` since it is not part
# of the public interface.
# Like `Char` but prints without quotes.
struct UnquotedChar <: AbstractChar
char::Char
function SparseArraysBase.isstored(
A::AnyAbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}
) where {N}
bI = BlockIndex(findblockindex.(axes(A), I))
bA = blocks(A)
return isstored(bA, bI.I...) && isstored(bA[bI.I...], bI.α...)
end
Base.show(io::IO, c::UnquotedChar) = print(io, c.char)
Base.show(io::IO, ::MIME"text/plain", c::UnquotedChar) = show(io, c)

using FillArrays: Fill
struct GetUnstoredBlockShow{Axes}
axes::Axes
end
@inline function (f::GetUnstoredBlockShow)(
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
) where {N}
# TODO: Make sure this works for sparse or block sparse blocks, immutable
# blocks, diagonal blocks, etc.!
b_size = ntuple(ndims(a)) do d
return length(f.axes[d][Block(I[d])])
function Base.replace_in_print_matrix(
A::AnyAbstractBlockSparseArray{<:Any,2}, i::Integer, j::Integer, s::AbstractString
)
return isstored(A, i, j) ? s : Base.replace_with_centered_mark(s)
end

# attempt to catch things that wrap GPU arrays
function Base.print_array(io::IO, X::AnyAbstractBlockSparseArray)
X_cpu = adapt(Array, X)
if typeof(X_cpu) === typeof(X) # prevent infinite recursion
# need to specify ndims to allow specialized code for vector/matrix
@allowscalar @invoke Base.print_array(
io, X_cpu::AbstractArray{eltype(X_cpu),ndims(X_cpu)}
)
else
Base.print_array(io, X_cpu)
end
return Fill(UnquotedChar('.'), b_size)
end
# TODO: Use `Base.to_indices`.
@inline function (f::GetUnstoredBlockShow)(
a::AbstractArray{<:Any,N}, I::CartesianIndex{N}
) where {N}
return f(a, Tuple(I)...)
end

# TODO: Make this an `@interface ::AbstractBlockSparseArrayInterface` function
# once we delete the hacky `Base.show` definitions in `BlockSparseArraysTensorAlgebraExt`.
function Base.show(io::IO, mime::MIME"text/plain", a::AnyAbstractBlockSparseArray)
summary(io, a)
isempty(a) && return nothing
print(io, ":")
println(io)
a′ = ReplacedUnstoredBlockSparseArray(a, GetUnstoredBlockShow(axes(a)))
@allowscalar Base.print_array(io, a′)
return nothing
end
2 changes: 1 addition & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,7 @@ arrayts = (Array, JLArray)
a = BlockSparseMatrix{elt,arrayt{elt,2}}([2, 2], [2, 2])
@allowscalar a[1, 2] = 12
@test sprint(show, "text/plain", a) ==
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ . .\n $(zero(eltype(a))) $(zero(eltype(a))) │ . .\n ───────────┼──────\n . .. .\n . .. ."
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) │ ⋅ ⋅ \n ───────────┼──────────\n ⋅ ⋅ \n ⋅ ⋅ "
end
end
@testset "TypeParameterAccessors.position" begin
Expand Down
Loading