7
7
Base. @propagate_inbounds Base. getindex (A:: LowDimArray , i:: Vararg{Union{Integer,CartesianIndex},K} ) where {K} = getindex (A. data, i... )
8
8
@inline Base. size (A:: LowDimArray ) = Base. size (A. data)
9
9
@inline Base. size (A:: LowDimArray , i) = Base. size (A. data, i)
10
- @inline Base . strides (A :: LowDimArray ) = strides (A . data)
10
+
11
11
@inline ArrayInterface. parent_type (:: Type{LowDimArray{D,T,N,A}} ) where {T,D,N,A} = A
12
- @inline ArrayInterface . strides (A:: LowDimArray ) = ArrayInterface . strides (A. data )
12
+ @inline Base . strides (A:: LowDimArray ) = map (Int, strides (A) )
13
13
@inline ArrayInterface. device (A:: LowDimArray ) = ArrayInterface. CPUPointer ()
14
14
@generated function ArrayInterface. size (A:: LowDimArray{D,T,N} ) where {D,T,N}
15
- t = Expr (:tuple )
16
- gf = GlobalRef (Core,:getfield )
17
- for n ∈ 1 : N
18
- if n > length (D) || D[n]
19
- push! (t. args, Expr (:call , gf, :s , n, false ))
20
- else
21
- push! (t. args, Expr (:call , Expr (:curly , lv (:StaticInt ), 1 )))
22
- end
15
+ t = Expr (:tuple )
16
+ gf = GlobalRef (Core,:getfield )
17
+ for n ∈ 1 : N
18
+ if n > length (D) || D[n]
19
+ push! (t. args, Expr (:call , gf, :s , n, false ))
20
+ else
21
+ push! (t. args, Expr (:call , Expr (:curly , lv (:StaticInt ), 1 )))
23
22
end
24
- Expr (:block , Expr (:meta ,:inline ), :(s = ArrayInterface. size (parent (A))), t)
23
+ end
24
+ Expr (:block , Expr (:meta ,:inline ), :(s = ArrayInterface. size (parent (A))), t)
25
25
end
26
26
Base. parent (A:: LowDimArray ) = getfield (A, :data )
27
27
Base. unsafe_convert (:: Type{Ptr{T}} , A:: LowDimArray{D,T} ) where {D,T} = pointer (parent (A))
@@ -31,103 +31,174 @@ ArrayInterface.stride_rank(A::LowDimArray) = ArrayInterface.stride_rank(parent(A
31
31
ArrayInterface. offsets (A:: LowDimArray ) = ArrayInterface. offsets (parent (A))
32
32
33
33
@generated function _lowdimfilter (:: Val{D} , tup:: Tuple{Vararg{Any,N}} ) where {D,N}
34
- t = Expr (:tuple )
35
- gf = GlobalRef (Core,:getfield )
36
- for n ∈ 1 : N
37
- if n > length (D) || D[n]
38
- push! (t. args, Expr (:call , gf, :tup , n, false ))
39
- end
34
+ t = Expr (:tuple )
35
+ gf = GlobalRef (Core,:getfield )
36
+ for n ∈ 1 : N
37
+ if n > length (D) || D[n]
38
+ push! (t. args, Expr (:call , gf, :tup , n, false ))
40
39
end
41
- Expr (:block , Expr (:meta ,:inline ), t)
40
+ end
41
+ Expr (:block , Expr (:meta ,:inline ), t)
42
42
end
43
43
44
44
struct ForBroadcast{T,N,A<: AbstractArray{T,N} } <: AbstractArray{T,N}
45
- data:: A
45
+ data:: A
46
46
end
47
47
@inline Base. parent (fb:: ForBroadcast ) = getfield (fb, :data )
48
+ @inline ArrayInterface. parent_type (:: Type{ForBroadcast{T,N,A}} ) where {T,N,A} = A
48
49
Base. @propagate_inbounds Base. getindex (A:: ForBroadcast , i:: Vararg{Any,K} ) where {K} = parent (A)[i... ]
49
50
const LowDimArrayForBroadcast{D,T,N,A} = ForBroadcast{T,N,LowDimArray{D,T,N,A}}
50
51
@inline function VectorizationBase. contiguous_axis (fb:: LowDimArrayForBroadcast{D,T,N,A} ) where {D,T,N,A}
51
- _contiguous_axis (Val {D} (), VectorizationBase. contiguous_axis (parent (parent (fb))))
52
+ _contiguous_axis (Val {D} (), VectorizationBase. contiguous_axis (parent (parent (fb))))
52
53
end
53
54
@inline forbroadcast (A:: AbstractArray ) = ForBroadcast (A)
54
55
@inline forbroadcast (A:: AbstractRange ) = A
55
56
@inline forbroadcast (A) = A
56
57
@inline forbroadcast (A:: Adjoint ) = forbroadcast (parent (A))
57
58
@inline forbroadcast (A:: Transpose ) = forbroadcast (parent (A))
58
- Base. size (fb:: ForBroadcast ) = size (parent (fb))
59
+ @inline function ArrayInterface. strides (A:: Union{LowDimArray,ForBroadcast} )
60
+ B = parent (A)
61
+ _strides (size (A), strides (B), VectorizationBase. val_stride_rank (B), VectorizationBase. val_dense_dims (B))
62
+ end
63
+
59
64
60
65
# @inline function VectorizationBase.contiguous_batch_size(fb::LowDimArrayForBroadcast{D,T,N,A}) where {D,T,N,A}
61
66
# _contiguous_axis(Val{D}(), VectorizationBase.contiguous_batch_size(parent(parent(fb))))
62
67
# end
63
68
@generated function _contiguous_axis (:: Val{D} , :: StaticInt{C} ) where {D,C}
64
- Dlen = length (D)
65
- (C > 0 ) || return Expr (:block ,Expr (:meta ,:inline ), staticexpr (C))
66
- if C ≤ Dlen
67
- D[C] || return Expr (:block ,Expr (:meta ,:inline ), staticexpr (- 1 ))
68
- end
69
- Cnew = 0
70
- for n ∈ 1 : C
71
- Cnew += ((n > Dlen)) || D[n]
72
- end
73
- Expr (:block ,Expr (:meta ,:inline ), staticexpr (Cnew))
69
+ Dlen = length (D)
70
+ (C > 0 ) || return Expr (:block ,Expr (:meta ,:inline ), staticexpr (C))
71
+ if C ≤ Dlen
72
+ D[C] || return Expr (:block ,Expr (:meta ,:inline ), staticexpr (- 1 ))
73
+ end
74
+ Cnew = 0
75
+ for n ∈ 1 : C
76
+ Cnew += ((n > Dlen)) || D[n]
77
+ end
78
+ Expr (:block ,Expr (:meta ,:inline ), staticexpr (Cnew))
74
79
end
75
- @inline function VectorizationBase . val_stride_rank (fb :: LowDimArrayForBroadcast{D} ) where {D}
76
- VectorizationBase . asvalint ( _lowdimfilter (Val (D), ArrayInterface. stride_rank (parent ( parent (fb))) ))
80
+ @inline function ArrayInterface . stride_rank ( :: Type{ LowDimArrayForBroadcast{D,T,N,A}} ) where {D,T,N,A }
81
+ _lowdimfilter (Val (D), ArrayInterface. stride_rank (A ))
77
82
end
78
- @inline function VectorizationBase . val_dense_dims (fb :: LowDimArrayForBroadcast{D} ) where {D}
79
- VectorizationBase . asvalbool ( _lowdimfilter (Val (D), ArrayInterface. dense_dims (parent ( parent (fb))) ))
83
+ @inline function ArrayInterface . dense_dims ( :: Type{ LowDimArrayForBroadcast{D,T,N,A}} ) where {D,T,N,A }
84
+ _lowdimfilter (Val (D), ArrayInterface. dense_dims (A ))
80
85
end
81
- @inline function VectorizationBase. bytestrides (fb:: LowDimArrayForBroadcast{D} ) where {D}
82
- p = parent (parent (fb))
83
- s = _bytestrides (ArrayInterface. size (p), ArrayInterface. strides (p), p)
84
- _lowdimfilter (Val (D), s)
86
+ @inline function ArrayInterface. strides (fb:: LowDimArrayForBroadcast{D} ) where {D}
87
+ _lowdimfilter (Val (D), strides (parent (fb)))
85
88
end
86
89
@inline function ArrayInterface. offsets (fb:: LowDimArrayForBroadcast{D} ) where {D}
87
- _lowdimfilter (Val (D), ArrayInterface. offsets (parent (parent (fb))))
90
+ _lowdimfilter (Val (D), ArrayInterface. offsets (parent (parent (fb))))
91
+ end
92
+ @inline function ArrayInterface. StrideIndex (a:: A ) where {A<: LowDimArrayForBroadcast }
93
+ _stride_index (ArrayInterface. stride_rank (A), ArrayInterface. contiguous_axis (A), a)
94
+ end
95
+ @inline function _stride_index (r:: Tuple{Vararg{StaticInt,N}} , :: StaticInt{C} , A) where {N,C}
96
+ StrideIndex {N,ArrayInterface.known(r),C} (A)
88
97
end
89
98
99
+ for f ∈ [ # groupedstridedpointer support
100
+ :(ArrayInterface. contiguous_axis),
101
+ :(ArrayInterface. contiguous_batch_size),
102
+ :(ArrayInterface. device),
103
+ :(ArrayInterface. stride_rank),
104
+ ]
105
+ @eval @inline $ f (:: Type{ForBroadcast{T,N,A}} ) where {T,N,A} = $ f (A)
106
+ end
90
107
for f ∈ [ # groupedstridedpointer support
91
108
:(VectorizationBase. memory_reference),
92
109
:(ArrayInterface. contiguous_axis),
93
110
:(ArrayInterface. contiguous_batch_size),
94
111
:(ArrayInterface. device),
95
- :(VectorizationBase . val_stride_rank ),
112
+ :(ArrayInterface . stride_rank ),
96
113
:(VectorizationBase. val_dense_dims),
97
- :(ArrayInterface. offsets)
114
+ :(ArrayInterface. offsets),
115
+ :(Base. size)# , :(ArrayInterface.strides)
98
116
]
99
117
@eval @inline $ f (fb:: ForBroadcast ) = $ f (getfield (fb, :data ))
100
118
end
101
- @inline _bytestrides (s,paren) = VectorizationBase. bytestrides (paren)
102
- @generated function _bytestrides (s:: Tuple{Vararg{Integer,N}} , x:: Tuple{Vararg{Integer,N}} , paren:: AbstractArray{T,N} ) where {T,N}
103
- q = Expr (:block , Expr (:meta ,:inline ))
104
- strd_tup = Expr (:tuple )
105
- gf = GlobalRef (Core, :getfield )
106
- ifel = GlobalRef (Core, :ifelse )
107
- st = staticexpr (sizeof (T))
108
- for n ∈ 1 : N
109
- s_type = s. parameters[n]
110
- if s_type <: Static
111
- if s_type === One
112
- push! (strd_tup. args, Expr (:call , lv (:Zero )))
113
- else
114
- push! (strd_tup. args, :($ gf (x, $ n, false ) * $ st))
115
- end
119
+
120
+ function is_column_major (x)
121
+ for (i, j) ∈ enumerate (x)
122
+ i == j || return false
123
+ end
124
+ true
125
+ end
126
+ is_row_major (x) = is_column_major (reverse (x))
127
+ # @inline _bytestrides(s,paren) = VectorizationBase.bytestrides(paren)
128
+ function _strides_expr (@nospecialize (s), @nospecialize (x), R:: Vector{Int} , D:: Vector{Bool} )
129
+ N = length (R)
130
+ q = Expr (:block , Expr (:meta ,:inline ))
131
+ strd_tup = Expr (:tuple )
132
+ gf = GlobalRef (Core, :getfield )
133
+ ifel = GlobalRef (Core, :ifelse )
134
+ Nrange = 1 : 1 : N # type stability w/ respect to reverse
135
+ use_stride_acc = true
136
+ stride_acc:: Int = 1
137
+ if is_column_major (R)
138
+ elseif is_row_major (R)
139
+ Nrange = reverse (Nrange)
140
+ else # not worth my time optimizing this case at the moment...
141
+ # will write something generic stride-rank agnostic eventually
142
+ use_stride_acc = false
143
+ stride_acc = 0
144
+ end
145
+ sₙ_value:: Int = 0
146
+ for n ∈ Nrange
147
+ xₙ_type = x[n]
148
+ # xₙ_type = typeof(x).parameters[n]
149
+ xₙ_static = xₙ_type <: Static
150
+ xₙ_value:: Int = xₙ_static ? (xₙ_type. parameters[1 ]):: Int : 0
151
+ s_type = s[n]
152
+ # s_type = typeof(s).parameters[n]
153
+ sₙ_static = s_type <: Static
154
+ if sₙ_static
155
+ sₙ_value = s_type. parameters[1 ]
156
+ if s_type === One
157
+ push! (strd_tup. args, Expr (:call , lv (:Zero )))
158
+ elseif stride_acc ≠ 0
159
+ push! (strd_tup. args, staticexpr (stride_acc))
160
+ else
161
+ push! (strd_tup. args, :($ gf (x, $ n, false )))
162
+ end
163
+ else
164
+ if xₙ_static
165
+ push! (strd_tup. args, staticexpr (xₙ_value))
166
+ elseif stride_acc ≠ 0
167
+ push! (strd_tup. args, staticexpr (stride_acc))
168
+ else
169
+ push! (strd_tup. args, :($ ifel (isone ($ gf (s, $ n, false )), zero ($ xₙ_type), $ gf (x, $ n, false ))))
170
+ end
171
+ end
172
+ if (n ≠ last (Nrange)) && use_stride_acc
173
+ nnext = n + step (Nrange)
174
+ if D[nnext]
175
+ if xₙ_static & sₙ_static
176
+ stride_acc = xₙ_value * sₙ_value
177
+ elseif sₙ_static
178
+ if stride_acc ≠ 0
179
+ stride_acc *= sₙ_value
180
+ end
116
181
else
117
- Xₙ_type = x. parameters[n]
118
- if Xₙ_type <: Static # FIXME ; what to do here? Dynamic dispatch?
119
- push! (strd_tup. args, :($ gf (x, $ n, false )* $ st))
120
- else
121
- push! (strd_tup. args, :($ ifel (isone ($ gf (s, $ n, false )), zero ($ Xₙ_type), $ gf (x, $ n, false )* $ st)))
122
- end
182
+ stride_acc = 0
123
183
end
184
+ else
185
+ stride_acc = 0
186
+ end
124
187
end
125
- push! (q. args, strd_tup)
126
- q
127
- end
128
- @inline function VectorizationBase. bytestrides (fb:: ForBroadcast )
129
- p = getfield (fb,:data )
130
- _bytestrides (ArrayInterface. size (p), ArrayInterface. strides (p), p)
188
+ end
189
+ push! (q. args, strd_tup)
190
+ q
191
+ end
192
+ @generated function _strides (
193
+ s:: Tuple{Vararg{Integer,N}} , x:: Tuple{Vararg{Integer,N}} , :: Val{R} , :: Val{D}
194
+ ) where {N,R,D}
195
+ Rv = Vector {Int} (undef, N)
196
+ Dv = Vector {Bool} (undef, N)
197
+ for n in 1 : N
198
+ Rv[n] = R[n]
199
+ Dv[n] = D[n]
200
+ end
201
+ _strides_expr (s. parameters, x. parameters, Rv, Dv)
131
202
end
132
203
133
204
struct Product{A,B}
0 commit comments