Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.3"
version = "0.2.4"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand All @@ -32,6 +33,7 @@ ArrayLayouts = "1.10.4"
BlockArrays = "1.2.0"
DerivableInterfaces = "0.3.7"
Dictionaries = "0.4.3"
FillArrays = "1.13.0"
GPUArraysCore = "0.1.0, 0.2"
GradedUnitRanges = "0.1.0"
LabelledNumbers = "0.1.0"
Expand Down
58 changes: 58 additions & 0 deletions src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,61 @@
function Base.Array(a::AnyAbstractBlockSparseArray)
return Array{eltype(a)}(a)
end

using SparseArraysBase: ReplacedUnstoredSparseArray

# Wraps a block sparse array but replaces the unstored values.
# This is used in printing in order to customize printing
# of zero/unstored values.
struct ReplacedUnstoredBlockSparseArray{T,N,F,Parent<:AbstractArray{T,N}} <:
AbstractBlockSparseArray{T,N}
parent::Parent
getunstoredblock::F
end
Base.parent(a::ReplacedUnstoredBlockSparseArray) = a.parent
Base.axes(a::ReplacedUnstoredBlockSparseArray) = axes(parent(a))
function BlockArrays.blocks(a::ReplacedUnstoredBlockSparseArray)
return ReplacedUnstoredSparseArray(blocks(parent(a)), a.getunstoredblock)

Check warning on line 352 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L349-L352

Added lines #L349 - L352 were not covered by tests
end

# This is copied from `SparseArraysBase.jl` since it is not part
# of the public interface.
# Like `Char` but prints without quotes.
struct UnquotedChar <: AbstractChar
char::Char
end
Base.show(io::IO, c::UnquotedChar) = print(io, c.char)
Base.show(io::IO, ::MIME"text/plain", c::UnquotedChar) = show(io, c)

Check warning on line 362 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L361-L362

Added lines #L361 - L362 were not covered by tests

using FillArrays: Fill
struct GetUnstoredBlockShow{Axes}
axes::Axes
end
@inline function (f::GetUnstoredBlockShow)(

Check warning on line 368 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L368

Added line #L368 was not covered by tests
a::AbstractArray{<:Any,N}, I::Vararg{Int,N}
) where {N}
# TODO: Make sure this works for sparse or block sparse blocks, immutable
# blocks, diagonal blocks, etc.!
b_size = ntuple(ndims(a)) do d
return length(f.axes[d][Block(I[d])])

Check warning on line 374 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L373-L374

Added lines #L373 - L374 were not covered by tests
end
return Fill(UnquotedChar('.'), b_size)

Check warning on line 376 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L376

Added line #L376 was not covered by tests
end
# TODO: Use `Base.to_indices`.
@inline function (f::GetUnstoredBlockShow)(

Check warning on line 379 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L379

Added line #L379 was not covered by tests
a::AbstractArray{<:Any,N}, I::CartesianIndex{N}
) where {N}
return f(a, Tuple(I)...)

Check warning on line 382 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L382

Added line #L382 was not covered by tests
end

# TODO: Make this an `@interface ::AbstractBlockSparseArrayInterface` function
# once we delete the hacky `Base.show` definitions in `BlockSparseArraysTensorAlgebraExt`.
function Base.show(io::IO, mime::MIME"text/plain", a::AnyAbstractBlockSparseArray)
summary(io, a)
isempty(a) && return nothing
print(io, ":")
println(io)
a′ = ReplacedUnstoredBlockSparseArray(a, GetUnstoredBlockShow(axes(a)))
Base.print_array(io, a′)
return nothing

Check warning on line 394 in src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl#L387-L394

Added lines #L387 - L394 were not covered by tests
end
12 changes: 12 additions & 0 deletions test/basics/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1045,4 +1045,16 @@ arrayts = (Array, JLArray)
@test blockstoredlength(b) == 2
@test storedlength(b) == 17
end
@testset "show" begin
if elt === Float64
# Not testing other element types since they change the
# spacing so it isn't easy to make the test general.
a = BlockSparseArray{elt}([2, 2], [2, 2])
a[1, 2] = 12
# TODO: Reenable this once we delete the hacky `Base.show` definitions
# in `BlockSparseArraysTensorAlgebraExt`.
@test_broken sprint(show, "text/plain", a) ==
"$(summary(a)):$(zero(eltype(a))) $(eltype(a)(12)) │ ⋅ ⋅\n $(zero(eltype(a))) $(zero(eltype(a))) │ ⋅ ⋅\n ───────────┼──────\n ⋅ ⋅ │ ⋅ ⋅\n ⋅ ⋅ │ ⋅ ⋅"
end
end
end
Loading