Skip to content

Commit 38a21f6

Browse files
committed
fix slicing
1 parent 705da7d commit 38a21f6

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ using NDTensors.GradedAxes:
1111
GradedUnitRangeDual,
1212
blocklabels,
1313
dual,
14-
gradedrange
14+
gradedrange,
15+
isdual
1516
using NDTensors.LabelledNumbers: label
1617
using NDTensors.SparseArrayInterface: nstored
1718
using NDTensors.TensorAlgebra: fusedims, splitdims
@@ -40,7 +41,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
4041
a = BlockSparseArray{elt}(d1, d2, d1, d2)
4142
blockdiagonal!(randn!, a)
4243
@test axes(a, 1) isa GradedOneTo
43-
@test axes(view(a, 1:4, 1:4), 1) isa GradedOneTo
44+
@test axes(view(a, 1:4, 1:4, 1:4, 1:4), 1) isa GradedOneTo
4445

4546
for b in (a + a, 2 * a)
4647
@test size(b) == (4, 4, 4, 4)
@@ -121,6 +122,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
121122
# TODO: Define and use `isdual` here.
122123
for dim in 1:ndims(a)
123124
@test typeof(ax[dim]) === typeof(axes(a, dim))
125+
@test isdual(ax[dim]) == isdual(axes(a, dim))
124126
end
125127
@test @view(a[Block(1, 1)])[1, 1] == a[1, 1]
126128
@test @view(a[Block(1, 1)])[2, 1] == a[2, 1]

NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,13 +323,13 @@ end
323323
# that mixed dense and graded axes.
324324
# TODO: Maybe come up with a more general solution.
325325
function BlockArrays.combine_blockaxes(
326-
a1::AbstractGradedUnitRange{T}, a2::Base.OneTo{T}
326+
a1::AbstractGradedUnitRange{<:LabelledInteger{T}}, a2::AbstractUnitRange{T}
327327
) where {T<:Integer}
328328
combined_blocklasts = sort!(union(unlabel.(blocklasts(a1)), blocklasts(a2)))
329329
return BlockedOneTo(combined_blocklasts)
330330
end
331331
function BlockArrays.combine_blockaxes(
332-
a1::Base.OneTo{T}, a2::AbstractGradedUnitRange{T}
332+
a1::AbstractUnitRange{T}, a2::AbstractGradedUnitRange{<:LabelledInteger{T}}
333333
) where {T<:Integer}
334334
return BlockArrays.combine_blockaxes(a2, a1)
335335
end

0 commit comments

Comments
 (0)