1
1
2
2
@inline stridedpointer_for_broadcast (A) = stridedpointer_for_broadcast (ArrayInterface. size (A), stridedpointer (A))
3
3
@inline stridedpointer_for_broadcast (s, ptr) = ptr
4
- function stridedpointer_for_broadcast (s, ptr:: VectorizationBase.AbstractStridedPointer )
5
- # FIXME : this is unsafe for AbstractStridedPointers
6
- throw (" Broadcasting not currently supported for arrays where typeof(stridedpointer(A)) === $(typeof (ptr)) " )
7
- end
8
- @generated function stridedpointer_for_broadcast (s :: Tuple{Vararg{Any,N}} , ptr :: StridedPointer{T,N,C,B,R,X,O} ) where {T,N,C,B,R,X,O}
4
+ # function stridedpointer_for_broadcast(s, ptr::VectorizationBase.AbstractStridedPointer)
5
+ # # FIXME : this is unsafe for AbstractStridedPointers
6
+ # throw("Broadcasting not currently supported for arrays where typeof(stridedpointer(A)) === $(typeof(ptr))")
7
+ # end
8
+ function stridedpointer_for_broadcast_quote (typ, N, S, X)
9
9
q = Expr (:block , Expr (:meta ,:inline ), :(strd = ptr. strd))
10
10
strd_tup = Expr (:tuple )
11
11
for n ∈ 1 : N
12
- s_type = s . parameters [n]
12
+ s_type = S [n]
13
13
if s_type <: Static
14
14
if s_type === Static{1 }
15
15
push! (strd_tup. args, Expr (:call , lv (:Zero )))
16
16
else
17
17
push! (strd_tup. args, :(strd[$ n]))
18
18
end
19
19
else
20
- Xₙ_type = X. parameters [n]
20
+ Xₙ_type = X[n]
21
21
if Xₙ_type <: Static # FIXME ; what to do here? Dynamic dispatch?
22
22
push! (strd_tup. args, :(strd[$ n]))
23
23
else
24
- push! (strd_tup. args, :(Base. ifelse (isone (s[$ n]), one ($ Xₙ_type), strd[$ n])))
24
+ push! (strd_tup. args, :(Base. ifelse (isone (s[$ n]), zero ($ Xₙ_type), strd[$ n])))
25
25
end
26
26
end
27
27
end
28
- push! (q. args, :(@inbounds StridedPointer {$T,$N,$C,$B,$R} (ptr. p, $ strd_tup, ptr. offsets)))
28
+ push! (q. args, :(@inbounds $ typ (ptr. p, $ strd_tup, ptr. offsets)))
29
29
q
30
30
end
31
+ @generated function stridedpointer_for_broadcast (s:: Tuple{Vararg{Any,N}} , ptr:: StridedPointer{T,N,C,B,R,X,O} ) where {T,N,C,B,R,X,O}
32
+ typ = Expr (:curly , :StridedPointer , T, N, C, B, R)
33
+ stridedpointer_for_broadcast_quote (typ, N, s. parameters, X. parameters)
34
+ end
35
+ @generated function stridedpointer_for_broadcast (s:: Tuple{Vararg{Any,N}} , ptr:: VectorizationBase.StridedBitPointer{N,C,B,R,X,O} ) where {N,C,B,R,X,O}
36
+ typ = Expr (:curly , :StridedBitPointer , N, C, B, R)
37
+ stridedpointer_for_broadcast_quote (typ, N, s. parameters, X. parameters)
38
+ end
31
39
32
40
struct Product{A,B}
33
41
a:: A
@@ -132,8 +140,9 @@ function add_broadcast!(
132
140
push! (ls. preamble_zeros, (identifier (setC), IntOrFloat))
133
141
setC. reduced_children = kvec
134
142
# compute Cₘₙ += Aₘₖ * Bₖₙ
143
+ instrsym = Base. libllvm_version < v " 11.0.0" ? :vfmadd231 : :vfmadd
135
144
reductop = Operation (
136
- ls, mC, elementbytes, :vfmadd231 , compute, reductdeps, kvec, Operation[loadA, loadB, setC]
145
+ ls, mC, elementbytes, instrsym , compute, reductdeps, kvec, Operation[loadA, loadB, setC]
137
146
)
138
147
reductop = pushop! (ls, reductop, mC)
139
148
reductfinal = Operation (
@@ -149,17 +158,18 @@ Base.@propagate_inbounds Base.getindex(A::LowDimArray, i...) = getindex(A.data,
149
158
@inline Base. size (A:: LowDimArray ) = Base. size (A. data)
150
159
@inline Base. size (A:: LowDimArray , i) = Base. size (A. data, i)
151
160
@inline Base. strides (A:: LowDimArray ) = strides (A. data)
161
+ @inline ArrayInterface. parent_type (:: Type{LowDimArray{D,T,N,A}} ) where {T,D,N,A} = A
152
162
@inline ArrayInterface. strides (A:: LowDimArray ) = ArrayInterface. strides (A. data)
153
163
@generated function ArrayInterface. size (A:: LowDimArray{D,T,N} ) where {D,T,N}
154
164
t = Expr (:tuple )
155
165
for n ∈ 1 : N
156
- if D[n]
166
+ if n > length (D) || D[n]
157
167
push! (t. args, Expr (:ref , :s , n))
158
168
else
159
169
push! (t. args, Expr (:call , Expr (:curly , lv (:Static ), 1 )))
160
170
end
161
171
end
162
- Expr (:block , Expr (:meta ,:inline ), :(s = size (A )), t)
172
+ Expr (:block , Expr (:meta ,:inline ), :(s = ArrayInterface . size (parent (A) )), t)
163
173
end
164
174
Base. parent (A:: LowDimArray ) = A. data
165
175
Base. unsafe_convert (:: Type{Ptr{T}} , A:: LowDimArray{D,T} ) where {D,T} = pointer (A. data)
@@ -168,6 +178,35 @@ ArrayInterface.contiguous_batch_size(A::LowDimArray) = ArrayInterface.contiguous
168
178
ArrayInterface. stride_rank (A:: LowDimArray ) = ArrayInterface. stride_rank (A. data)
169
179
ArrayInterface. offsets (A:: LowDimArray ) = ArrayInterface. offsets (A. data)
170
180
181
+ @inline function stridedpointer_for_broadcast (A:: LowDimArray{D} ) where {D}
182
+ _stridedpointer (stridedpointer_for_broadcast (parent (A)), Val {D} ())
183
+ end
184
+
185
+ @generated function _stridedpointer (p:: StridedPointer{T,N,C,B,R} , :: Val{D} ) where {T,N,C,B,R,D}
186
+ lenD = length (D)
187
+ strd = Expr (:tuple )
188
+ offsets = Expr (:tuple )
189
+ Rtup = Expr (:tuple )
190
+ Cnew = - 1
191
+ Bnew = - 1
192
+ Nnew = 0
193
+ for n ∈ 1 : N
194
+ ((n ≤ lenD) && (! D[n])) && continue
195
+ if n == C
196
+ Cnew = n
197
+ end
198
+ if n == B
199
+ Bnew = n
200
+ end
201
+ push! (Rtup. args, R[n])
202
+ push! (offsets. args, Expr (:ref , :offs , n))
203
+ push! (strd. args, Expr (:ref , :strd , n))
204
+ Nnew += 1
205
+ end
206
+ typ = Expr (:curly , :StridedPointer , T, Nnew, Cnew, Bnew, Rtup)
207
+ ptr = Expr (:call , typ, :(pointer (p)), strd, offsets)
208
+ Expr (:block , Expr (:meta ,:inline ), :(strd = p. strd), :(offs = p. offsets), ptr)
209
+ end
171
210
# @generated function VectorizationBase.stridedpointer(A::LowDimArray{D,T,N}) where {D,T,N}
172
211
# smul = Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:VectorizationBase)), QuoteNode(:staticmul))
173
212
# multup = Expr(:tuple)
@@ -188,22 +227,22 @@ function LowDimArray{D}(data::A) where {D,T,N,A <: AbstractArray{T,N}}
188
227
end
189
228
function extract_all_1_array! (ls:: LoopSet , bcname:: Symbol , N:: Int , elementbytes:: Int )
190
229
refextract = gensym (bcname)
191
- ref = Expr (:ref , bcname); append ! (ref. args, [ 1 for n ∈ 1 : N] )
230
+ ref = Expr (:ref , bcname); foreach (_ -> push ! (ref. args, :begin ), 1 : N)
192
231
pushprepreamble! (ls, Expr (:(= ), refextract, ref))
193
232
return add_constant! (ls, refextract, elementbytes) # or replace elementbytes with sizeof(T) ? u
194
233
end
195
234
function add_broadcast! (
196
235
ls:: LoopSet , destname:: Symbol , bcname:: Symbol , loopsyms:: Vector{Symbol} ,
197
- @nospecialize (LDA:: Type{<: LowDimArray} ), elementbytes:: Int
198
- )
199
- D,T,N:: Int ,_ = LDA. parameters
236
+ @nospecialize (LDA:: Type{LowDimArray{D,T,N,A} } ), elementbytes:: Int
237
+ ) where {D,T,N,A}
238
+ # D,T,N::Int,_ = LDA.parameters
200
239
Dlen = length (D)
201
- if Dlen == N && ! any (D)
240
+ if Dlen == N && ! any (D) # array is a scalar, as it is broadcasted on all dimensions
202
241
return extract_all_1_array! (ls, bcname, N, elementbytes)
203
242
end
204
243
fulldims = Symbol[loopsyms[n] for n ∈ 1 : N if ((Dlen < n) || D[n]:: Bool )]
205
244
ref = ArrayReference (bcname, fulldims)
206
- add_simple_load! (ls, destname, ref, elementbytes, true , false ):: Operation
245
+ add_simple_load! (ls, destname, ref, elementbytes, true , true ):: Operation
207
246
end
208
247
function add_broadcast_adjoint_array! (
209
248
ls:: LoopSet , destname:: Symbol , bcname:: Symbol , loopsyms:: Vector{Symbol} , :: Type{A} , elementbytes:: Int
@@ -218,8 +257,11 @@ function add_broadcast_adjoint_array!(
218
257
ls:: LoopSet , destname:: Symbol , bcname:: Symbol , loopsyms:: Vector{Symbol} , :: Type{<:AbstractVector} , elementbytes:: Int
219
258
)
220
259
# isone(length(loopsyms)) && return extract_all_1_array!(ls, bcname, N, elementbytes)
221
- ref = ArrayReference (bcname, Symbol[loopsyms[2 ]])
222
- add_simple_load! ( ls, destname, ref, elementbytes, true , true )
260
+ parent = gensym (:parent )
261
+ pushprepreamble! (ls, Expr (:(= ), parent, Expr (:call , :parent , bcname)))
262
+
263
+ ref = ArrayReference (parent, Symbol[loopsyms[2 ]])
264
+ add_simple_load! ( ls, destname, ref, elementbytes, true , true ):: Operation
223
265
end
224
266
function add_broadcast! (
225
267
ls:: LoopSet , destname:: Symbol , bcname:: Symbol , loopsyms:: Vector{Symbol} ,
0 commit comments