Skip to content

Commit fdb56cc

Browse files
authored
[GradedAxes] [BlockSparseArrays] Fix ambiguity error when slicing GradedUnitRange with BlockSlice (#1491)
* [GradedAxes] [BlockSparseArrays] Fix ambiguity issue when slicing GradedUnitRange with BlockSlice * [NDTensors] Bump to v0.3.25
1 parent ccad1a4 commit fdb56cc

File tree

3 files changed

+66
-30
lines changed

3 files changed

+66
-30
lines changed

ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -95,37 +95,48 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
9595
end
9696
@testset "dual axes" begin
9797
r = gradedrange([U1(0) => 2, U1(1) => 2])
98-
a = BlockSparseArray{elt}(dual(r), r)
99-
@views for b in [Block(1, 1), Block(2, 2)]
100-
a[b] = randn(elt, size(a[b]))
101-
end
102-
# TODO: Define and use `isdual` here.
103-
@test axes(a, 1) isa UnitRangeDual
104-
@test axes(a, 2) isa GradedUnitRange
105-
@test !(axes(a, 2) isa UnitRangeDual)
106-
a_dense = Array(a)
107-
@test eachindex(a) == CartesianIndices(size(a))
108-
for I in eachindex(a)
109-
@test a[I] == a_dense[I]
110-
end
111-
@test axes(a') == dual.(reverse(axes(a)))
112-
# TODO: Define and use `isdual` here.
113-
@test axes(a', 1) isa UnitRangeDual
114-
@test axes(a', 2) isa GradedUnitRange
115-
@test !(axes(a', 2) isa UnitRangeDual)
116-
@test isnothing(show(devnull, MIME("text/plain"), a))
117-
118-
# Check preserving dual in tensor algebra.
119-
for b in (a + a, 2 * a, 3 * a - a)
120-
@test Array(b) 2 * Array(a)
98+
for ax in ((r, r), (dual(r), r), (r, dual(r)), (dual(r), dual(r)))
99+
a = BlockSparseArray{elt}(ax...)
100+
@views for b in [Block(1, 1), Block(2, 2)]
101+
a[b] = randn(elt, size(a[b]))
102+
end
121103
# TODO: Define and use `isdual` here.
122-
@test axes(b, 1) isa UnitRangeDual
123-
@test axes(b, 2) isa GradedUnitRange
124-
@test !(axes(b, 2) isa UnitRangeDual)
125-
end
104+
for dim in 1:ndims(a)
105+
@test typeof(ax[dim]) === typeof(axes(a, dim))
106+
end
107+
@test @view(a[Block(1, 1)])[1, 1] == a[1, 1]
108+
@test @view(a[Block(1, 1)])[2, 1] == a[2, 1]
109+
@test @view(a[Block(1, 1)])[1, 2] == a[1, 2]
110+
@test @view(a[Block(1, 1)])[2, 2] == a[2, 2]
111+
@test @view(a[Block(2, 2)])[1, 1] == a[3, 3]
112+
@test @view(a[Block(2, 2)])[2, 1] == a[4, 3]
113+
@test @view(a[Block(2, 2)])[1, 2] == a[3, 4]
114+
@test @view(a[Block(2, 2)])[2, 2] == a[4, 4]
115+
@test @view(a[Block(1, 1)])[1:2, 1:2] == a[1:2, 1:2]
116+
@test @view(a[Block(2, 2)])[1:2, 1:2] == a[3:4, 3:4]
117+
a_dense = Array(a)
118+
@test eachindex(a) == CartesianIndices(size(a))
119+
for I in eachindex(a)
120+
@test a[I] == a_dense[I]
121+
end
122+
@test axes(a') == dual.(reverse(axes(a)))
123+
# TODO: Define and use `isdual` here.
124+
@test typeof(axes(a', 1)) === typeof(dual(axes(a, 2)))
125+
@test typeof(axes(a', 2)) === typeof(dual(axes(a, 1)))
126+
@test isnothing(show(devnull, MIME("text/plain"), a))
126127

127-
@test isnothing(show(devnull, MIME("text/plain"), @view(a[Block(1, 1)])))
128-
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
128+
# Check preserving dual in tensor algebra.
129+
for b in (a + a, 2 * a, 3 * a - a)
130+
@test Array(b) 2 * Array(a)
131+
# TODO: Define and use `isdual` here.
132+
for dim in 1:ndims(a)
133+
@test typeof(axes(b, dim)) === typeof(axes(b, dim))
134+
end
135+
end
136+
137+
@test isnothing(show(devnull, MIME("text/plain"), @view(a[Block(1, 1)])))
138+
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
139+
end
129140

130141
# Test case when all axes are dual.
131142
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,21 @@ using Dictionaries: Dictionary, Indices
1919
using ..GradedAxes: blockedunitrange_getindices
2020
using ..SparseArrayInterface: stored_indices
2121

22+
# GenericBlockSlice works around an issue that the indices of BlockSlice
23+
# are restricted to Int element type.
24+
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
25+
struct GenericBlockSlice{B,T<:Integer,I<:AbstractUnitRange{T}} <: AbstractUnitRange{T}
26+
block::B
27+
indices::I
28+
end
29+
BlockArrays.Block(bs::GenericBlockSlice{<:Block}) = bs.block
30+
for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe_length)
31+
@eval Base.$f(S::GenericBlockSlice) = Base.$f(S.indices)
32+
end
33+
Base.getindex(S::GenericBlockSlice, i::Integer) = getindex(S.indices, i)
34+
35+
# BlockIndices works around an issue that the indices of BlockSlice
36+
# are restricted to AbstractUnitRange{Int}.
2237
struct BlockIndices{B,T<:Integer,I<:AbstractVector{T}} <: AbstractVector{T}
2338
blocks::B
2439
indices::I
@@ -175,6 +190,13 @@ function blockrange(axis::AbstractUnitRange, r::BlockSlice)
175190
return blockrange(axis, r.block)
176191
end
177192

193+
# GenericBlockSlice works around an issue that the indices of BlockSlice
194+
# are restricted to Int element type.
195+
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
196+
function blockrange(axis::AbstractUnitRange, r::GenericBlockSlice)
197+
return blockrange(axis, r.block)
198+
end
199+
178200
function blockrange(a::AbstractUnitRange, r::BlockIndices)
179201
return blockrange(a, r.blocks)
180202
end

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ end
7878
# TODO: Move to blocksparsearrayinterface.
7979
function blocksparse_unblock(a, inds, I::Tuple{AbstractUnitRange{<:Integer},Vararg{Any}})
8080
bs = blockrange(inds[1], I[1])
81-
return BlockSlice(bs, blockedunitrange_getindices(inds[1], I[1]))
81+
# GenericBlockSlice works around an issue that the indices of BlockSlice
82+
# are restricted to Int element type.
83+
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
84+
return GenericBlockSlice(bs, blockedunitrange_getindices(inds[1], I[1]))
8285
end
8386

8487
function BlockArrays.unblock(a, inds, I::Tuple{AbstractVector{<:Block{1}},Vararg{Any}})

0 commit comments

Comments
 (0)