@@ -73,7 +73,7 @@ function broadcast_t(f::Any, ::Type{Any}, ::Any, ::Any, A::GPUArrays.GPUArray, a
73
73
end
74
74
75
75
deref (x) = x
76
- deref (x:: RefValue ) = (x[],) # RefValue doesn't work with CUDAnative
76
+ deref (x:: RefValue ) = (x[],) # RefValue doesn't work with CUDAnative so we use Tuple
77
77
78
78
function _broadcast! (
79
79
func, out:: GPUArray ,
@@ -84,9 +84,9 @@ function _broadcast!(
84
84
shape = UInt32 .(size (out))
85
85
args = (A, Bs... )
86
86
descriptor_tuple = ntuple (length (args)) do i
87
- BroadcastDescriptor (args[i], keeps[i], Idefaults[i])
87
+ BInfo (args[i], keeps[i], Idefaults[i])
88
88
end
89
- gpu_call (broadcast_kernel!, out, (func, out, shape, UInt32 ( length (out)), descriptor_tuple, A, deref .(Bs)... ))
89
+ gpu_call (broadcast_kernel!, out, (func, out, shape, descriptor_tuple, ( A, deref .(Bs)... ) ))
90
90
out
91
91
end
92
92
@@ -97,9 +97,9 @@ function Base.foreach(func, over::GPUArray, Bs...)
97
97
keeps, Idefaults = map_newindexer (shape, over, Bs)
98
98
args = (over, Bs... )
99
99
descriptor_tuple = ntuple (length (args)) do i
100
- BroadcastDescriptor (args[i], keeps[i], Idefaults[i])
100
+ BInfo (args[i], keeps[i], Idefaults[i])
101
101
end
102
- gpu_call (foreach_kernel, over, (func, shape, UInt32 .( length (over)), descriptor_tuple, over, deref .(Bs)... ))
102
+ gpu_call (foreach_kernel, over, (func, shape, descriptor_tuple, ( over, deref .(Bs)... ) ))
103
103
return
104
104
end
105
105
@@ -116,51 +116,60 @@ arg_length(x::Tuple) = (UInt32(length(x)),)
116
116
arg_length (x:: GPUArray ) = UInt32 .(size (x))
117
117
arg_length (x) = () # Scalar
118
118
119
- abstract type BroadcastDescriptor{Typ} end
120
-
121
- struct BroadcastDescriptorN{Typ, N} <: BroadcastDescriptor{Typ}
119
+ struct BInfo{Typ, N}
122
120
size:: NTuple{N, UInt32}
123
121
keep:: NTuple{N, UInt32}
124
122
idefault:: NTuple{N, UInt32}
125
123
end
126
- function BroadcastDescriptor (val:: RefValue , keep, idefault)
127
- BroadcastDescriptorN {Tuple, 1} ((UInt32 (1 ),), (UInt32 (0 ),), (UInt32 (1 ),))
128
- end
129
124
130
- function BroadcastDescriptor (val, keep, idefault)
125
+ function BInfo (val, keep, idefault)
131
126
N = length (keep)
132
127
typ = Broadcast. containertype (val)
133
- BroadcastDescriptorN {typ, N} (arg_length (val), UInt32 .(keep), UInt32 .(idefault))
128
+ BInfo {typ, N} (arg_length (val), UInt32 .(keep), UInt32 .(idefault))
134
129
end
135
130
136
131
@propagate_inbounds @inline function _broadcast_getindex (
137
- :: BroadcastDescriptor {Array} , A, I
132
+ :: BInfo {Array} , A, I
138
133
)
139
134
A[I]
140
135
end
141
136
@propagate_inbounds @inline function _broadcast_getindex (
142
- :: BroadcastDescriptor {Tuple} , A, I
137
+ :: BInfo {Tuple} , A, I
143
138
)
144
139
A[I]
145
140
end
146
141
@propagate_inbounds @inline function _broadcast_getindex (
147
- :: BroadcastDescriptor {Array} , A:: Ref , I
142
+ :: BInfo {Array} , A:: Ref , I
148
143
)
149
144
A[]
150
145
end
151
146
152
147
@inline _broadcast_getindex (any, A, I) = A
153
148
154
- for N = 0 : 10
149
+ @inline function broadcast_kernel! (state, func, out, shape, descriptor, args)
150
+ ilin = @linearidx (out, state)
151
+ @inbounds out[ilin] = apply_broadcast (ilin, func, shape, descriptor, args)
152
+ return
153
+ end
154
+ function foreach_kernel (state, func, shape, descriptor, args)
155
+ ilin = @linearidx (args[1 ], state)
156
+ apply_broadcast (ilin, func, shape, descriptor, args)
157
+ return
158
+ end
159
+
160
+ function mapidx_kernel (state, f, A, args)
161
+ ilin = @linearidx (A, state)
162
+ f (ilin, A, args... )
163
+ return
164
+ end
165
+
166
+ for N = 0 : 15
155
167
nargs = N + 1
156
168
inner_expr = []
157
- args = []
158
169
valargs = []
159
170
for i = 1 : N
160
- Ai = Symbol (" A_" , i);
161
171
val_i = Symbol (" val_" , i); I_i = Symbol (" I_" , i);
162
172
desi = Symbol (" deref_" , i)
163
- push! (args, Ai)
164
173
inner = quote
165
174
# destructure the keeps and As tuples
166
175
$ desi = descriptor[$ i]
@@ -172,51 +181,25 @@ for N = 0:10
172
181
$ desi. size
173
182
)
174
183
# extract array values
175
- @inbounds $ val_i = _broadcast_getindex ($ desi, $ Ai , $ I_i)
184
+ @inbounds $ val_i = _broadcast_getindex ($ desi, args[ $ i] , $ I_i)
176
185
end
177
186
push! (inner_expr, inner)
178
187
push! (valargs, val_i)
179
188
end
180
189
@eval begin
181
-
182
- @inline function apply_broadcast (ilin, state, func, shape, len, descriptor, $ (args... ))
190
+ @inline function apply_broadcast (ilin, func, shape, descriptor, args:: NTuple{$N, Any} )
183
191
# this will hopefully get dead code removed,
184
192
# if only arrays with linear index are involved, because I should be unused in that case
185
193
I = gpu_ind2sub (shape, ilin)
186
194
$ (inner_expr... )
187
195
# call the function and store the result
188
196
func ($ (valargs... ))
189
197
end
190
-
191
- @inline function broadcast_kernel! (state, func, B, shape, len, descriptor, $ (args... ))
192
- ilin = linear_index (state)
193
- if ilin <= len
194
- @inbounds B[ilin] = apply_broadcast (ilin, state, func, shape, len, descriptor, $ (args... ))
195
- end
196
- return
197
- end
198
-
199
- function foreach_kernel (state, func, shape, len, descriptor, A, $ (args... ))
200
- ilin = linear_index (state)
201
- if ilin <= len
202
- apply_broadcast (ilin, state, func, shape, len, descriptor, A, $ (args... ))
203
- end
204
- return
205
- end
206
-
207
- function mapidx_kernel (state, f, A, len, $ (args... ))
208
- i = linear_index (state)
209
- @inbounds if i <= len
210
- f (i, A, $ (args... ))
211
- end
212
- return
213
- end
214
198
end
215
-
216
199
end
217
200
218
201
function mapidx (f, A:: GPUArray , args:: NTuple{N, Any} ) where N
219
- gpu_call (mapidx_kernel, A, (f, A, UInt32 ( length (A)), args... ))
202
+ gpu_call (mapidx_kernel, A, (f, A, args))
220
203
end
221
204
222
205
# don't do anything for empty tuples
0 commit comments