87
87
88
88
89
89
"""
90
- contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<: Val}}
90
+ contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{Val}}
91
91
92
92
Returns a tuple boolean `Val`s indicating whether that axis is contiguous.
93
93
"""
@@ -98,53 +98,6 @@ contiguous_axis_indicator(::Nothing, ::Val) = nothing
98
98
Base. @pure contiguous_axis_indicator (:: Contiguous{N} , :: Val{D} ) where {N,D} =
99
99
ntuple (d -> Val {d == N} (), Val {D} ())
100
100
101
- """
102
- If the contiguous dimension is not the dimension with `StrideRank{1}`:
103
- """
104
- struct ContiguousBatch{N} end
105
- Base. @pure ContiguousBatch (N:: Int ) = ContiguousBatch {N} ()
106
- _get (:: ContiguousBatch{N} ) where {N} = N
107
-
108
- """
109
- contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
110
-
111
- Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
112
- If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
113
- If `contiguous_axis(T) == -1`, it will return `ContiguousBatch{-1}()`.
114
- If unknown, it will return `nothing`.
115
- """
116
- contiguous_batch_size (x) = contiguous_batch_size (typeof (x))
117
- contiguous_batch_size (:: Type ) = nothing
118
- contiguous_batch_size (:: Type{Array{T,N}} ) where {T,N} = ContiguousBatch {0} ()
119
- contiguous_batch_size (:: Type{<:Tuple} ) = ContiguousBatch {0} ()
120
- contiguous_batch_size (
121
- :: Type{<:Union{Transpose{T,A},Adjoint{T,A}}} ,
122
- ) where {T,A<: AbstractVecOrMat{T} } = contiguous_batch_size (A)
123
- contiguous_batch_size (
124
- :: Type{<:PermutedDimsArray{T,N,I1,I2,A}} ,
125
- ) where {T,N,I1,I2,A<: AbstractArray{T,N} } = contiguous_batch_size (A)
126
- function contiguous_batch_size (
127
- :: Type{S} ,
128
- ) where {N,NP,T,A<: AbstractArray{T,NP} ,I,S<: SubArray{T,N,A,I} }
129
- _contiguous_batch_size (S, contiguous_batch_size (A), contiguous_axis (A))
130
- end
131
- _contiguous_batch_size (:: Any , :: Any , :: Any ) = nothing
132
- @generated function _contiguous_batch_size (
133
- :: Type{S} ,
134
- :: ContiguousBatch{B} ,
135
- :: Contiguous{C} ,
136
- ) where {B,C,N,NP,T,A<: AbstractArray{T,NP} ,I,S<: SubArray{T,N,A,I} }
137
- if I. parameters[C] <: AbstractUnitRange
138
- Expr (:call , Expr (:curly , :ContiguousBatch , B))
139
- else
140
- Expr (:call , Expr (:curly , :ContiguousBatch , - 1 ))
141
- end
142
- end
143
-
144
- contiguous_batch_size (
145
- :: Type{R} ,
146
- ) where {T,N,S,A<: Array{S} ,R<: Base.ReinterpretArray{T,N,S,A} } = ContiguousBatch {0} ()
147
-
148
101
struct StrideRank{R} end
149
102
Base. @pure StrideRank (R:: NTuple{<:Any,Int} ) = StrideRank {R} ()
150
103
_get (:: StrideRank{R} ) where {R} = R
@@ -230,6 +183,67 @@ stride_rank(x, i) = stride_rank(x)[i]
230
183
stride_rank (:: Type{R} ) where {T,N,S,A<: Array{S} ,R<: Base.ReinterpretArray{T,N,S,A} } =
231
184
StrideRank {ntuple(identity, Val{N}())} ()
232
185
186
+ function stride_rank (:: Type {Base. ReshapedArray{T, N, P, Tuple{Vararg{Base. SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
187
+
188
+ _reshaped_striderank (is_column_major (P), Val {N} (), Val {M} ())
189
+ end
190
+ _reshaped_striderank (:: Val{true} , :: Val{N} , :: Val{0} ) where {N} = StrideRank {ntuple(identity, Val{N}())} ()
191
+ _reshaped_striderank (_, __, ___) = nothing
192
+
193
+
194
+ """
195
+ If the contiguous dimension is not the dimension with `StrideRank{1}`:
196
+ """
197
+ struct ContiguousBatch{N} end
198
+ Base. @pure ContiguousBatch (N:: Int ) = ContiguousBatch {N} ()
199
+ _get (:: ContiguousBatch{N} ) where {N} = N
200
+
201
+ """
202
+ contiguous_batch_size(::Type{T}) -> ContiguousBatch{N}
203
+
204
+ Returns the Base.size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
205
+ If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
206
+ If `contiguous_axis(T) == -1`, it will return `ContiguousBatch{-1}()`.
207
+ If unknown, it will return `nothing`.
208
+ """
209
+ contiguous_batch_size (x) = contiguous_batch_size (typeof (x))
210
+ contiguous_batch_size (:: Type{T} ) where {T} = _contiguous_batch_size (contiguous_axis (T), stride_rank (T))
211
+ _contiguous_batch_size (_, __) = nothing
212
+ @generated function _contiguous_batch_size (:: Contiguous{D} , :: StrideRank{R} ) where {D,R}
213
+ isone (R[D]) ? :(ContiguousBatch {0} ()) : :nothing
214
+ end
215
+
216
+ contiguous_batch_size (:: Type{Array{T,N}} ) where {T,N} = ContiguousBatch {0} ()
217
+ contiguous_batch_size (:: Type{<:Tuple} ) = ContiguousBatch {0} ()
218
+ contiguous_batch_size (
219
+ :: Type{<:Union{Transpose{T,A},Adjoint{T,A}}} ,
220
+ ) where {T,A<: AbstractVecOrMat{T} } = contiguous_batch_size (A)
221
+ contiguous_batch_size (
222
+ :: Type{<:PermutedDimsArray{T,N,I1,I2,A}} ,
223
+ ) where {T,N,I1,I2,A<: AbstractArray{T,N} } = contiguous_batch_size (A)
224
+ function contiguous_batch_size (
225
+ :: Type{S} ,
226
+ ) where {N,NP,T,A<: AbstractArray{T,NP} ,I,S<: SubArray{T,N,A,I} }
227
+ _contiguous_batch_size (S, contiguous_batch_size (A), contiguous_axis (A))
228
+ end
229
+ _contiguous_batch_size (:: Any , :: Any , :: Any ) = nothing
230
+ @generated function _contiguous_batch_size (
231
+ :: Type{S} ,
232
+ :: ContiguousBatch{B} ,
233
+ :: Contiguous{C} ,
234
+ ) where {B,C,N,NP,T,A<: AbstractArray{T,NP} ,I,S<: SubArray{T,N,A,I} }
235
+ if I. parameters[C] <: AbstractUnitRange
236
+ Expr (:call , Expr (:curly , :ContiguousBatch , B))
237
+ else
238
+ Expr (:call , Expr (:curly , :ContiguousBatch , - 1 ))
239
+ end
240
+ end
241
+
242
+ contiguous_batch_size (
243
+ :: Type{R} ,
244
+ ) where {T,N,S,A<: Array{S} ,R<: Base.ReinterpretArray{T,N,S,A} } = ContiguousBatch {0} ()
245
+
246
+
233
247
"""
234
248
is_column_major(A) -> Val{true/false}()
235
249
@@ -260,7 +274,8 @@ An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A
260
274
"""
261
275
dense_dims (x) = dense_dims (typeof (x))
262
276
dense_dims (:: Type ) = nothing
263
- dense_dims (:: Type{Array{T,N}} ) where {T,N} = DenseDims {ntuple(_ -> true, Val{N}())} ()
277
+ _all_dense (:: Val{N} ) where {N} = DenseDims {ntuple(_ -> true, Val{N}())} ()
278
+ dense_dims (:: Type{Array{T,N}} ) where {T,N} = _all_dense (Val {N} ())
264
279
dense_dims (:: Type{<:Tuple} ) = DenseDims {(true,)} ()
265
280
function dense_dims (
266
281
:: Type{<:Union{Transpose{T,A},Adjoint{T,A}}} ,
@@ -306,6 +321,15 @@ _dense_dims(::Any, ::Any) = nothing
306
321
length (dense_tup. args) == N ? Expr (:call , Expr (:curly , :DenseDims , dense_tup)) : nothing
307
322
end
308
323
324
+ function dense_dims (:: Type {Base. ReshapedArray{T, N, P, Tuple{Vararg{Base. SignedMultiplicativeInverse{Int},M}}}}) where {T,N,P,M}
325
+
326
+ _reshaped_dense_dims (dense_dims (P), is_column_major (P), Val {N} (), Val {M} ())
327
+ end
328
+ _reshaped_dense_dims (_, __, ___, ____) = nothing
329
+ @generated function _reshaped_dense_dims (:: DenseDims{D} , :: Val{true} , :: Val{N} , :: Val{0} ) where {D,N}
330
+ all (D) ? :(_all_dense (Val {$N} ())) : :nothing
331
+ end
332
+
309
333
permute (t:: NTuple{N} , I:: NTuple{N,Int} ) where {N} = ntuple (n -> t[I[n]], Val {N} ())
310
334
@generated function permute (t:: Tuple{Vararg{Any,N}} , :: Val{I} ) where {N,I}
311
335
t = Expr (:tuple )
0 commit comments