@@ -37,11 +37,9 @@ function axes_types(::Type{T}) where {T}
37
37
end
38
38
axes_types (:: Type{LinearIndices{N,R}} ) where {N,R} = R
39
39
axes_types (:: Type{CartesianIndices{N,R}} ) where {N,R} = R
40
- function axes_types (:: Type{T} ) where {T<: VecAdjTrans }
41
- return Tuple{OptionallyStaticUnitRange{One,One},axes_types (parent_type (T), One ())}
42
- end
43
- function axes_types (:: Type{T} ) where {T<: MatAdjTrans }
44
- return eachop_tuple (_get_tuple, to_parent_dims (T), axes_types (parent_type (T)))
40
+ function axes_types (:: Type{T} ) where {T<: Union{Adjoint,Transpose} }
41
+ P = parent_type (T)
42
+ return Tuple{axes_types (P, static (2 )), axes_types (P, static (1 ))}
45
43
end
46
44
function axes_types (:: Type{T} ) where {T<: PermutedDimsArray }
47
45
return eachop_tuple (_get_tuple, to_parent_dims (T), axes_types (parent_type (T)))
@@ -133,6 +131,21 @@ function axes(a::A, dim::Integer) where {A}
133
131
return axes (parent (a), to_parent_dims (A, dim))
134
132
end
135
133
end
134
+ function axes (A:: CartesianIndices{N} , dim:: Integer ) where {N}
135
+ if dim > N
136
+ return static (1 ): static (1 )
137
+ else
138
+ return getfield (axes (A), Int (dim))
139
+ end
140
+ end
141
+ function axes (A:: LinearIndices{N} , dim:: Integer ) where {N}
142
+ if dim > N
143
+ return static (1 ): static (1 )
144
+ else
145
+ return getfield (axes (A), Int (dim))
146
+ end
147
+ end
148
+
136
149
axes (A:: SubArray , dim:: Integer ) = Base. axes (A, Int (dim)) # TODO implement ArrayInterface version
137
150
axes (A:: ReinterpretArray , dim:: Integer ) = Base. axes (A, Int (dim)) # TODO implement ArrayInterface version
138
151
axes (A:: Base.ReshapedArray , dim:: Integer ) = Base. axes (A, Int (dim)) # TODO implement ArrayInterface version
@@ -160,3 +173,137 @@ axes(A::Base.ReshapedArray) = Base.axes(A) # TODO implement ArrayInterface vers
160
173
axes (A:: CartesianIndices ) = A. indices
161
174
axes (A:: LinearIndices ) = A. indices
162
175
176
+ """
177
+ LazyAxis{N}(parent::AbstractArray)
178
+
179
+ A lazy representation of `axes(parent, N)`.
180
+ """
181
+ struct LazyAxis{N,P} <: AbstractUnitRange{Int}
182
+ parent:: P
183
+
184
+ LazyAxis {N} (parent:: P ) where {N,P} = new {N::Int,P} (parent)
185
+ @inline function LazyAxis {:} (parent:: P ) where {P}
186
+ if ndims (P) === 1
187
+ return new {1,P} (parent)
188
+ else
189
+ return new {:,P} (parent)
190
+ end
191
+ end
192
+ end
193
+
194
+ @inline Base. parent (x:: LazyAxis{N,P} ) where {N,P} = axes (getfield (x, :parent ), static (N))
195
+ @inline function Base. parent (x:: LazyAxis{:,P} ) where {P}
196
+ return eachindex (IndexLinear (), getfield (x, :parent ))
197
+ end
198
+
199
+ @inline parent_type (:: Type{LazyAxis{N,P}} ) where {N,P} = axes_types (P, static (N))
200
+ # TODO this approach to parent_type(::Type{LazyAxis{:}}) is a bit hacky. Something like
201
+ # LabelledArrays has a linear set of symbolic keys, which could be propagated through
202
+ # `to_indices` for key based indexing. However, there currently isn't a good way of handling
203
+ # that when the linear indices aren't linearly accessible through a child array (e.g, adjoint)
204
+ # For now we just make sure the linear elements are accurate.
205
+ parent_type (:: Type{LazyAxis{:,P}} ) where {P<: Array } = OneTo{Int}
206
+ @inline function parent_type (:: Type{LazyAxis{:,P}} ) where {P}
207
+ if known_length (P) === nothing
208
+ return OptionallyStaticUnitRange{StaticInt{1 },Int}
209
+ else
210
+ return OptionallyStaticUnitRange{StaticInt{1 },StaticInt{known_length (P)}}
211
+ end
212
+ end
213
+
214
+ Base. keys (x:: LazyAxis ) = keys (parent (x))
215
+
216
+ Base. IndexStyle (:: Type{T} ) where {T<: LazyAxis } = IndexStyle (parent_type (T))
217
+
218
+ can_change_size (:: Type{LazyAxis{N,P}} ) where {N,P} = can_change_size (P)
219
+
220
+ known_first (:: Type{T} ) where {T<: LazyAxis } = known_first (parent_type (T))
221
+
222
+ known_length (:: Type{LazyAxis{N,P}} ) where {N,P} = known_size (P, N)
223
+ known_length (:: Type{LazyAxis{:,P}} ) where {P} = known_length (P)
224
+
225
+ @inline function known_last (:: Type{T} ) where {T<: LazyAxis }
226
+ return _lazy_axis_known_last (known_first (T), known_length (T))
227
+ end
228
+ _lazy_axis_known_last (start:: Int , length:: Int ) = (length + start) - 1
229
+ _lazy_axis_known_last (:: Any , :: Any ) = nothing
230
+
231
+ @inline function Base. first (x:: LazyAxis{N} ):: Int where {N}
232
+ if known_first (x) === nothing
233
+ return offsets (getfield (x, :parent ), static (N))
234
+ else
235
+ return known_first (x)
236
+ end
237
+ end
238
+ @inline function Base. first (x:: LazyAxis{:} ):: Int
239
+ if known_first (x) === nothing
240
+ return firstindex (getfield (x, :parent ))
241
+ else
242
+ return known_first (x)
243
+ end
244
+ end
245
+
246
+ @inline function Base. length (x:: LazyAxis{N} ):: Int where {N}
247
+ if known_length (x) === nothing
248
+ return size (getfield (x, :parent ), static (N))
249
+ else
250
+ return known_length (x)
251
+ end
252
+ end
253
+ @inline function Base. length (x:: LazyAxis{:} ):: Int
254
+ if known_length (x) === nothing
255
+ return lastindex (getfield (x, :parent ))
256
+ else
257
+ return known_length (x)
258
+ end
259
+ end
260
+
261
+ @inline function Base. last (x:: LazyAxis ):: Int
262
+ if known_last (x) === nothing
263
+ if known_first (x) === 1
264
+ return length (x)
265
+ else
266
+ return (static_length (x) + static_first (x)) - 1
267
+ end
268
+ else
269
+ return known_last (x)
270
+ end
271
+ end
272
+
273
+ Base. to_shape (x:: LazyAxis ) = length (x)
274
+
275
+ @inline function Base. checkindex (:: Type{Bool} , x:: LazyAxis , i:: Integer )
276
+ if known_first (x) === nothing || known_last (x) === nothing
277
+ return checkindex (Bool, parent (x), i)
278
+ else # everything is static so we don't have to retrieve the axis
279
+ return (! (known_first (x) > i) || ! (known_last (x) < i))
280
+ end
281
+ end
282
+
283
+ @propagate_inbounds function Base. getindex (x:: LazyAxis , i:: Integer )
284
+ @boundscheck checkindex (Bool, x, i) || throw (BoundsError (x, i))
285
+ return Int (i)
286
+ end
287
+ @propagate_inbounds Base. getindex (x:: LazyAxis , i:: StepRange{T} ) where {T<: Integer } = parent (x)[i]
288
+ @propagate_inbounds Base. getindex (x:: LazyAxis , i:: AbstractUnitRange{<:Integer} ) = parent (x)[i]
289
+
290
+ Base. show (io:: IO , x:: LazyAxis{N} ) where {N} = print (io, " LazyAxis{$N }($(parent (x)) ))" )
291
+
292
+ """
293
+ lazy_axes(x)
294
+
295
+ Produces a tuple of axes where each axis is constructed lazily. If an axis of `x` is already
296
+ constructed or it is simply retrieved.
297
+ """
298
+ @generated function lazy_axes (x:: X ) where {X}
299
+ Expr (:block ,
300
+ Expr (:meta , :inline ),
301
+ Expr (:tuple , [:(LazyAxis {$dim} (x)) for dim in 1 : ndims (X)]. .. )
302
+ )
303
+ end
304
+ lazy_axes (x:: LinearIndices ) = axes (x)
305
+ lazy_axes (x:: CartesianIndices ) = axes (x)
306
+ @inline lazy_axes (x:: MatAdjTrans ) = reverse (lazy_axes (parent (x)))
307
+ @inline lazy_axes (x:: VecAdjTrans ) = (LazyAxis {1} (x), first (lazy_axes (parent (x))))
308
+ @inline lazy_axes (x:: PermutedDimsArray ) = permute (lazy_axes (parent (x)), to_parent_dims (A))
309
+
0 commit comments