Skip to content

Commit efe1c7c

Browse files
committed
custom type for AbstractGradedUnitRange
1 parent d7e3a8d commit efe1c7c

File tree

2 files changed

+52
-36
lines changed

2 files changed

+52
-36
lines changed

NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,26 @@ using ..LabelledNumbers:
2929
labelled_isequal,
3030
unlabel
3131

32-
const AbstractGradedUnitRange{T<:LabelledInteger} = AbstractBlockedUnitRange{T}
32+
abstract type AbstractGradedUnitRange{T,CS} <: AbstractBlockedUnitRange{T,CS} end
3333

34-
const GradedUnitRange{T<:LabelledInteger,BlockLasts<:Vector{T}} = BlockedUnitRange{
35-
T,BlockLasts
36-
}
34+
struct GradedUnitRange{T,BlockLasts<:Vector{T}} <: AbstractGradedUnitRange{T,BlockLasts}
35+
first::T
36+
lasts::BlockLasts
37+
end
3738

38-
const GradedOneTo{T<:LabelledInteger,BlockLasts<:Vector{T}} = BlockedOneTo{T,BlockLasts}
39+
struct GradedOneTo{T,BlockLasts<:Vector{T}} <: AbstractGradedUnitRange{T,BlockLasts}
40+
lasts::BlockLasts
3941

40-
# This is only needed in certain Julia versions below 1.10
41-
# (for example Julia 1.6).
42-
# TODO: Delete this once we drop Julia 1.6 support.
43-
function Base.OrdinalRange{T,T}(a::GradedOneTo{<:LabelledInteger{T}}) where {T}
44-
return unlabel_blocks(a)
42+
# assume that lasts is sorted, no checks carried out here
43+
function GradedOneTo(lasts::CS) where {T<:Integer,CS<:AbstractVector{T}}
44+
Base.require_one_based_indexing(lasts)
45+
isempty(lasts) || first(lasts) >= 0 || throw(ArgumentError("blocklasts must be >= 0"))
46+
return new{T,CS}(lasts)
47+
end
48+
function GradedOneTo(lasts::CS) where {T<:Integer,CS<:Tuple{T,Vararg{T}}}
49+
first(lasts) >= 0 || throw(ArgumentError("blocklasts must be >= 0"))
50+
return new{T,CS}(lasts)
51+
end
4552
end
4653

4754
# == is just a range comparison that ignores labels. Need dedicated function to check equality.
@@ -90,7 +97,7 @@ Base.eltype(::Type{<:GradedUnitRange{T}}) where {T} = T
9097
function gradedrange(lblocklengths::AbstractVector{<:LabelledInteger})
9198
brange = blockedrange(unlabel.(lblocklengths))
9299
lblocklasts = labelled.(blocklasts(brange), label.(lblocklengths))
93-
return BlockedOneTo(lblocklasts)
100+
return GradedOneTo(lblocklasts)
94101
end
95102

96103
# To help with generic code.
@@ -118,14 +125,12 @@ end
118125
function labelled_blocks(a::BlockedOneTo, labels)
119126
# TODO: Use `blocklasts(a)`? That might
120127
# cause a recursive loop.
121-
return BlockedOneTo(labelled.(a.lasts, labels))
128+
return GradedOneTo(labelled.(a.lasts, labels))
122129
end
123130
function labelled_blocks(a::BlockedUnitRange, labels)
124131
# TODO: Use `first(a)` and `blocklasts(a)`? Those might
125132
# cause a recursive loop.
126-
return BlockArrays._BlockedUnitRange(
127-
labelled(a.first, labels[1]), labelled.(a.lasts, labels)
128-
)
133+
return GradedUnitRange(labelled(a.first, labels[1]), labelled.(a.lasts, labels))
129134
end
130135

131136
function BlockArrays.findblock(a::AbstractGradedUnitRange, index::Integer)
@@ -185,7 +190,15 @@ function unlabel_blocks(a::BlockedUnitRange)
185190
return BlockArrays._BlockedUnitRange(a.first, unlabel.(a.lasts))
186191
end
187192

188-
## BlockedUnitRage interface
193+
function unlabel_blocks(a::GradedOneTo)
194+
# TODO: Use `blocklasts(a)`.
195+
return BlockedOneTo(unlabel.(a.lasts))
196+
end
197+
function unlabel_blocks(a::GradedUnitRange)
198+
return BlockArrays._BlockedUnitRange(a.first, unlabel.(a.lasts))
199+
end
200+
201+
## BlockedUnitRange interface
189202

190203
function Base.axes(ga::AbstractGradedUnitRange)
191204
return map(axes(unlabel_blocks(ga))) do a
@@ -217,9 +230,6 @@ end
217230
function Base.first(a::AbstractGradedUnitRange)
218231
return gradedunitrange_first(a)
219232
end
220-
function Base.first(a::GradedOneTo)
221-
return gradedunitrange_first(a)
222-
end
223233

224234
Base.iterate(a::AbstractGradedUnitRange) = isempty(a) ? nothing : (first(a), first(a))
225235
function Base.iterate(a::AbstractGradedUnitRange, i)
@@ -232,7 +242,7 @@ function firstblockindices(a::AbstractGradedUnitRange)
232242
return labelled.(firstblockindices(unlabel_blocks(a)), blocklabels(a))
233243
end
234244

235-
function blockedunitrange_getindex(a::AbstractGradedUnitRange, index)
245+
function gradedunitrange_getindices(a::AbstractGradedUnitRange, index)
236246
# This uses `blocklasts` since that is what is stored
237247
# in `BlockedUnitRange`, maybe abstract that away.
238248
return labelled(unlabel_blocks(a)[index], get_label(a, index))
@@ -245,27 +255,34 @@ function blocklabels(a::AbstractUnitRange, indices)
245255
end
246256
end
247257

248-
function blockedunitrange_getindices(
258+
function gradedunitrange_getindices(
249259
ga::AbstractGradedUnitRange, indices::AbstractUnitRange{<:Integer}
250260
)
251261
a_indices = blockedunitrange_getindices(unlabel_blocks(ga), indices)
252262
return labelled_blocks(a_indices, blocklabels(ga, indices))
253263
end
254264

265+
function gradedunitrange_getindices(
266+
a::AbstractGradedUnitRange,
267+
indices::Union{AbstractVector{<:Block{1}},AbstractVector{<:BlockIndexRange{1}}},
268+
)
269+
return blockedunitrange_getindices(a, indices)
270+
end
271+
255272
# Fixes ambiguity error with:
256273
# ```julia
257-
# blockedunitrange_getindices(::GradedUnitRange, ::AbstractUnitRange{<:Integer})
274+
# gradedunitrange_getindices(::GradedUnitRange, ::AbstractUnitRange{<:Integer})
258275
# ```
259276
# TODO: Try removing once GradedAxes is rewritten for BlockArrays v1.
260-
function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockSlice)
277+
function gradedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockSlice)
261278
return a[indices.block]
262279
end
263280

264-
function blockedunitrange_getindices(ga::AbstractGradedUnitRange, indices::BlockRange)
281+
function gradedunitrange_getindices(ga::AbstractGradedUnitRange, indices::BlockRange)
265282
return labelled_blocks(unlabel_blocks(ga)[indices], blocklabels(ga, indices))
266283
end
267284

268-
function blockedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockIndex{1})
285+
function gradedunitrange_getindices(a::AbstractGradedUnitRange, indices::BlockIndex{1})
269286
return a[block(indices)][blockindex(indices)]
270287
end
271288

@@ -276,7 +293,7 @@ function Base.getindex(a::AbstractGradedUnitRange, index::Integer)
276293
end
277294

278295
function Base.getindex(a::AbstractGradedUnitRange, index::Block{1})
279-
return blockedunitrange_getindex(a, index)
296+
return gradedunitrange_getindices(a, index)
280297
end
281298

282299
function Base.getindex(a::AbstractGradedUnitRange, indices::BlockIndexRange)
@@ -286,18 +303,18 @@ end
286303
function Base.getindex(
287304
a::AbstractGradedUnitRange, indices::BlockRange{1,<:Tuple{AbstractUnitRange{Int}}}
288305
)
289-
return blockedunitrange_getindices(a, indices)
306+
return gradedunitrange_getindices(a, indices)
290307
end
291308

292309
# Fixes ambiguity error with `BlockArrays`.
293310
function Base.getindex(
294311
a::AbstractGradedUnitRange, indices::BlockRange{1,Tuple{Base.OneTo{Int}}}
295312
)
296-
return blockedunitrange_getindices(a, indices)
313+
return gradedunitrange_getindices(a, indices)
297314
end
298315

299316
function Base.getindex(a::AbstractGradedUnitRange, indices::BlockIndex{1})
300-
return blockedunitrange_getindices(a, indices)
317+
return gradedunitrange_getindices(a, indices)
301318
end
302319

303320
# Fixes ambiguity issues with:
@@ -310,15 +327,15 @@ end
310327
# TODO: Maybe not needed once GradedAxes is rewritten
311328
# for BlockArrays v1.
312329
function Base.getindex(a::AbstractGradedUnitRange, indices::BlockSlice)
313-
return blockedunitrange_getindices(a, indices)
330+
return gradedunitrange_getindices(a, indices)
314331
end
315332

316333
function Base.getindex(a::AbstractGradedUnitRange, indices)
317-
return blockedunitrange_getindices(a, indices)
334+
return gradedunitrange_getindices(a, indices)
318335
end
319336

320337
function Base.getindex(a::AbstractGradedUnitRange, indices::AbstractUnitRange{<:Integer})
321-
return blockedunitrange_getindices(a, indices)
338+
return gradedunitrange_getindices(a, indices)
322339
end
323340

324341
# This fixes an issue that `combine_blockaxes` was promoting
@@ -352,7 +369,7 @@ end
352369
# blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices))
353370
# return blockedrange(blocklengths)
354371
# ```
355-
function blockedunitrange_getindices(
372+
function gradedunitrange_getindices(
356373
a::AbstractGradedUnitRange, indices::AbstractBlockVector{<:Block{1}}
357374
)
358375
blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices))

NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
struct GradedUnitRangeDual{
2-
T<:LabelledInteger,NondualUnitRange<:AbstractGradedUnitRange{T}
3-
} <: AbstractGradedUnitRange{T,Vector{T}}
1+
struct GradedUnitRangeDual{T,CS,NondualUnitRange<:AbstractGradedUnitRange{T,CS}} <:
2+
AbstractGradedUnitRange{T,CS}
43
nondual_unitrange::NondualUnitRange
54
end
65

0 commit comments

Comments
 (0)