@@ -73,7 +73,7 @@ function broadcast_t(f::Any, ::Type{Any}, ::Any, ::Any, A::GPUArrays.GPUArray, a
73
73
end
74
74
75
75
deref (x) = x
76
- deref (x:: RefValue ) = (x[],) # RefValue doesn't work with CUDAnative
76
+ deref (x:: RefValue ) = (x[],) # RefValue doesn't work with CUDAnative so we use Tuple
77
77
78
78
function _broadcast! (
79
79
func, out:: GPUArray ,
@@ -84,9 +84,9 @@ function _broadcast!(
84
84
shape = UInt32 .(size (out))
85
85
args = (A, Bs... )
86
86
descriptor_tuple = ntuple (length (args)) do i
87
- BroadcastDescriptor (args[i], keeps[i], Idefaults[i])
87
+ BInfo (args[i], keeps[i], Idefaults[i])
88
88
end
89
- gpu_call (broadcast_kernel!, out, (func, out, shape, UInt32 ( length (out)), descriptor_tuple, A, deref .(Bs)... ))
89
+ gpu_call (broadcast_kernel!, out, (func, out, shape, descriptor_tuple, ( A, deref .(Bs)... ) ))
90
90
out
91
91
end
92
92
@@ -97,9 +97,9 @@ function Base.foreach(func, over::GPUArray, Bs...)
97
97
keeps, Idefaults = map_newindexer (shape, over, Bs)
98
98
args = (over, Bs... )
99
99
descriptor_tuple = ntuple (length (args)) do i
100
- BroadcastDescriptor (args[i], keeps[i], Idefaults[i])
100
+ BInfo (args[i], keeps[i], Idefaults[i])
101
101
end
102
- gpu_call (foreach_kernel, over, (func, shape, UInt32 .( length (over)), descriptor_tuple, over, deref .(Bs)... ))
102
+ gpu_call (foreach_kernel, over, (func, shape, descriptor_tuple, ( over, deref .(Bs)... ) ))
103
103
return
104
104
end
105
105
@@ -109,51 +109,60 @@ arg_length(x::Tuple) = (UInt32(length(x)),)
109
109
arg_length (x:: GPUArray ) = UInt32 .(size (x))
110
110
arg_length (x) = ()
111
111
112
- abstract type BroadcastDescriptor{Typ} end
113
-
114
- struct BroadcastDescriptorN{Typ, N} <: BroadcastDescriptor{Typ}
112
+ struct BInfo{Typ, N}
115
113
size:: NTuple{N, UInt32}
116
114
keep:: NTuple{N, UInt32}
117
115
idefault:: NTuple{N, UInt32}
118
116
end
119
- function BroadcastDescriptor (val:: RefValue , keep, idefault)
120
- BroadcastDescriptorN {Tuple, 1} ((UInt32 (1 ),), (UInt32 (0 ),), (UInt32 (1 ),))
121
- end
122
117
123
- function BroadcastDescriptor (val, keep, idefault)
118
+ function BInfo (val, keep, idefault)
124
119
N = length (keep)
125
120
typ = Broadcast. containertype (val)
126
- BroadcastDescriptorN {typ, N} (arg_length (val), UInt32 .(keep), UInt32 .(idefault))
121
+ BInfo {typ, N} (arg_length (val), UInt32 .(keep), UInt32 .(idefault))
127
122
end
128
123
129
124
@propagate_inbounds @inline function _broadcast_getindex (
130
- :: BroadcastDescriptor {Array} , A, I
125
+ :: BInfo {Array} , A, I
131
126
)
132
127
A[I]
133
128
end
134
129
@propagate_inbounds @inline function _broadcast_getindex (
135
- :: BroadcastDescriptor {Tuple} , A, I
130
+ :: BInfo {Tuple} , A, I
136
131
)
137
132
A[I]
138
133
end
139
134
@propagate_inbounds @inline function _broadcast_getindex (
140
- :: BroadcastDescriptor {Array} , A:: Ref , I
135
+ :: BInfo {Array} , A:: Ref , I
141
136
)
142
137
A[]
143
138
end
144
139
145
140
@inline _broadcast_getindex (any, A, I) = A
146
141
147
- for N = 0 : 10
142
+ @inline function broadcast_kernel! (state, func, out, shape, descriptor, args)
143
+ ilin = @linearidx (out, state)
144
+ @inbounds out[ilin] = apply_broadcast (ilin, func, shape, descriptor, args)
145
+ return
146
+ end
147
+ function foreach_kernel (state, func, shape, descriptor, args)
148
+ ilin = @linearidx (args[1 ], state)
149
+ apply_broadcast (ilin, func, shape, descriptor, args)
150
+ return
151
+ end
152
+
153
+ function mapidx_kernel (state, f, A, args)
154
+ ilin = @linearidx (A, state)
155
+ f (ilin, A, args... )
156
+ return
157
+ end
158
+
159
+ for N = 0 : 15
148
160
nargs = N + 1
149
161
inner_expr = []
150
- args = []
151
162
valargs = []
152
163
for i = 1 : N
153
- Ai = Symbol (" A_" , i);
154
164
val_i = Symbol (" val_" , i); I_i = Symbol (" I_" , i);
155
165
desi = Symbol (" deref_" , i)
156
- push! (args, Ai)
157
166
inner = quote
158
167
# destructure the keeps and As tuples
159
168
$ desi = descriptor[$ i]
@@ -165,51 +174,25 @@ for N = 0:10
165
174
$ desi. size
166
175
)
167
176
# extract array values
168
- @inbounds $ val_i = _broadcast_getindex ($ desi, $ Ai , $ I_i)
177
+ @inbounds $ val_i = _broadcast_getindex ($ desi, args[ $ i] , $ I_i)
169
178
end
170
179
push! (inner_expr, inner)
171
180
push! (valargs, val_i)
172
181
end
173
182
@eval begin
174
-
175
- @inline function apply_broadcast (ilin, state, func, shape, len, descriptor, $ (args... ))
183
+ @inline function apply_broadcast (ilin, func, shape, descriptor, args:: NTuple{$N, Any} )
176
184
# this will hopefully get dead code removed,
177
185
# if only arrays with linear index are involved, because I should be unused in that case
178
186
I = gpu_ind2sub (shape, ilin)
179
187
$ (inner_expr... )
180
188
# call the function and store the result
181
189
func ($ (valargs... ))
182
190
end
183
-
184
- @inline function broadcast_kernel! (state, func, B, shape, len, descriptor, $ (args... ))
185
- ilin = linear_index (state)
186
- if ilin <= len
187
- @inbounds B[ilin] = apply_broadcast (ilin, state, func, shape, len, descriptor, $ (args... ))
188
- end
189
- return
190
- end
191
-
192
- function foreach_kernel (state, func, shape, len, descriptor, A, $ (args... ))
193
- ilin = linear_index (state)
194
- if ilin <= len
195
- apply_broadcast (ilin, state, func, shape, len, descriptor, A, $ (args... ))
196
- end
197
- return
198
- end
199
-
200
- function mapidx_kernel (state, f, A, len, $ (args... ))
201
- i = linear_index (state)
202
- @inbounds if i <= len
203
- f (i, A, $ (args... ))
204
- end
205
- return
206
- end
207
191
end
208
-
209
192
end
210
193
211
194
function mapidx (f, A:: GPUArray , args:: NTuple{N, Any} ) where N
212
- gpu_call (mapidx_kernel, A, (f, A, UInt32 ( length (A)), args... ))
195
+ gpu_call (mapidx_kernel, A, (f, A, args))
213
196
end
214
197
215
198
# don't do anything for empty tuples
0 commit comments