Skip to content

Commit 24fb3c8

Browse files
committed
fix combine_blockaxes
1 parent de6080b commit 24fb3c8

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
3939
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
4040
a = BlockSparseArray{elt}(d1, d2, d1, d2)
4141
blockdiagonal!(randn!, a)
42+
@test axes(a, 1) isa GradedOneTo
43+
@test axes(view(a, 1:4, 1:4), 1) isa GradedOneTo
4244

4345
for b in (a + a, 2 * a)
4446
@test size(b) == (4, 4, 4, 4)

NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ using BlockArrays:
1717
blocklengths,
1818
findblock,
1919
findblockindex,
20-
mortar
20+
mortar,
21+
sortedunion
2122
using Compat: allequal
2223
using FillArrays: Fill
2324
using ..LabelledNumbers:
@@ -304,17 +305,25 @@ end
304305
# that mixed dense and graded axes.
305306
# TODO: Maybe come up with a more general solution.
306307
function BlockArrays.combine_blockaxes(
307-
a1::GradedOneTo{<:LabelledInteger{T}}, a2::Base.OneTo{T}
308+
a1::AbstractGradedUnitRange{T}, a2::Base.OneTo{T}
308309
) where {T<:Integer}
309310
combined_blocklasts = sort!(union(unlabel.(blocklasts(a1)), blocklasts(a2)))
310311
return BlockedOneTo(combined_blocklasts)
311312
end
312313
function BlockArrays.combine_blockaxes(
313-
a1::Base.OneTo{T}, a2::GradedOneTo{<:LabelledInteger{T}}
314+
a1::Base.OneTo{T}, a2::AbstractGradedUnitRange{T}
314315
) where {T<:Integer}
315316
return BlockArrays.combine_blockaxes(a2, a1)
316317
end
317318

319+
# preserve labels inside combine_blockaxes
320+
# TODO dual
321+
function BlockArrays.combine_blockaxes(
322+
a::AbstractGradedUnitRange, b::AbstractGradedUnitRange
323+
)
324+
return gradedrange(sortedunion(blocklasts(a), blocklasts(b)))
325+
end
326+
318327
# Version of length that checks that all blocks have the same label
319328
# and returns a labelled length with that label.
320329
function labelled_length(a::AbstractBlockVector{<:Integer})

0 commit comments

Comments
 (0)