|
335 | 335 | function Base.Array(a::AnyAbstractBlockSparseArray)
|
336 | 336 | return Array{eltype(a)}(a)
|
337 | 337 | end
|
| 338 | + |
| 339 | +using SparseArraysBase: ReplacedUnstoredSparseArray |
| 340 | + |
| 341 | +# Wraps a block sparse array but replaces the unstored values. |
| 342 | +# This is used in printing in order to customize printing |
| 343 | +# of zero/unstored values. |
| 344 | +struct ReplacedUnstoredBlockSparseArray{T,N,F,Parent<:AbstractArray{T,N}} <: |
| 345 | + AbstractBlockSparseArray{T,N} |
| 346 | + parent::Parent |
| 347 | + getunstoredblock::F |
| 348 | +end |
| 349 | +Base.parent(a::ReplacedUnstoredBlockSparseArray) = a.parent |
| 350 | +Base.axes(a::ReplacedUnstoredBlockSparseArray) = axes(parent(a)) |
| 351 | +function BlockArrays.blocks(a::ReplacedUnstoredBlockSparseArray) |
| 352 | + return ReplacedUnstoredSparseArray(blocks(parent(a)), a.getunstoredblock) |
| 353 | +end |
| 354 | + |
| 355 | +# This is copied from `SparseArraysBase.jl` since it is not part |
| 356 | +# of the public interface. |
| 357 | +# Like `Char` but prints without quotes. |
| 358 | +struct UnquotedChar <: AbstractChar |
| 359 | + char::Char |
| 360 | +end |
| 361 | +Base.show(io::IO, c::UnquotedChar) = print(io, c.char) |
| 362 | +Base.show(io::IO, ::MIME"text/plain", c::UnquotedChar) = show(io, c) |
| 363 | + |
| 364 | +using FillArrays: Fill |
| 365 | +struct GetUnstoredBlockShow{Axes} |
| 366 | + axes::Axes |
| 367 | +end |
| 368 | +@inline function (f::GetUnstoredBlockShow)( |
| 369 | + a::AbstractArray{<:Any,N}, I::Vararg{Int,N} |
| 370 | +) where {N} |
| 371 | + # TODO: Make sure this works for sparse or block sparse blocks, immutable |
| 372 | + # blocks, diagonal blocks, etc.! |
| 373 | + b_size = ntuple(ndims(a)) do d |
| 374 | + return length(f.axes[d][Block(I[d])]) |
| 375 | + end |
| 376 | + return Fill(UnquotedChar('.'), b_size) |
| 377 | +end |
| 378 | +# TODO: Use `Base.to_indices`. |
| 379 | +@inline function (f::GetUnstoredBlockShow)( |
| 380 | + a::AbstractArray{<:Any,N}, I::CartesianIndex{N} |
| 381 | +) where {N} |
| 382 | + return f(a, Tuple(I)...) |
| 383 | +end |
| 384 | + |
| 385 | +# TODO: Make this an `@interface ::AbstractBlockSparseArrayInterface` function |
| 386 | +# once we delete the hacky `Base.show` definitions in `BlockSparseArraysTensorAlgebraExt`. |
| 387 | +function Base.show(io::IO, mime::MIME"text/plain", a::AnyAbstractBlockSparseArray) |
| 388 | + summary(io, a) |
| 389 | + isempty(a) && return nothing |
| 390 | + print(io, ":") |
| 391 | + println(io) |
| 392 | + a′ = ReplacedUnstoredBlockSparseArray(a, GetUnstoredBlockShow(axes(a))) |
| 393 | + Base.print_array(io, a′) |
| 394 | + return nothing |
| 395 | +end |
0 commit comments