@@ -94,29 +94,7 @@ function transform_gpu!(def, constargs, force_inbounds)
94
94
if force_inbounds
95
95
push! (new_stmts, Expr (:inbounds , true ))
96
96
end
97
-
98
- # fix convergence
99
- active_stmts = Any[]
100
- for stmt in stmts
101
- has_sync = find_sync (stmt)
102
- if has_sync
103
- push! (new_stmts, Expr (:if , :__active_lane__ , Expr (:block , active_stmts... )))
104
- empty! (active_stmts)
105
- push! (new_stmts, stmt)
106
- continue
107
- end
108
- if @capture (stmt, @uniform x_)
109
- push! (new_stmts, stmt)
110
- continue
111
- elseif @capture (stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
112
- if @capture (rhs, @localmem (args__) | @uniform (args__))
113
- push! (new_stmts, stmt)
114
- continue
115
- end
116
- end
117
- push! (active_stmts, stmt)
118
- end
119
- push! (new_stmts, Expr (:if , :__active_lane__ , Expr (:block , active_stmts... )))
97
+ append! (new_stmts, split (emit_gpu, body. args))
120
98
if force_inbounds
121
99
push! (new_stmts, Expr (:inbounds , :pop ))
122
100
end
@@ -151,7 +129,7 @@ function transform_cpu!(def, constargs, force_inbounds)
151
129
if force_inbounds
152
130
push! (new_stmts, Expr (:inbounds , true ))
153
131
end
154
- append! (new_stmts, split (body. args))
132
+ append! (new_stmts, split (emit_cpu, body. args))
155
133
if force_inbounds
156
134
push! (new_stmts, Expr (:inbounds , :pop ))
157
135
end
191
169
192
170
# TODO proper handling of LineInfo
193
171
function split (
172
+ emit,
194
173
stmts,
195
174
indicies = Any[], private = Set {Symbol} (),
196
175
)
@@ -221,7 +200,7 @@ function split(
221
200
function recurse (expr:: Expr )
222
201
expr = unblock (expr)
223
202
if is_scope_construct (expr) && any (find_sync, expr. args)
224
- new_args = unblock (split (expr. args, deepcopy (indicies), deepcopy (private)))
203
+ new_args = unblock (split (emit, expr. args, deepcopy (indicies), deepcopy (private)))
225
204
return Expr (expr. head, new_args... )
226
205
else
227
206
return Expr (expr. head, map (recurse, expr. args)... )
@@ -270,7 +249,7 @@ function split(
270
249
return new_stmts
271
250
end
272
251
273
- function emit (loop)
252
+ function emit_cpu (loop)
274
253
idx = gensym (:I )
275
254
for stmt in loop. indicies
276
255
# splice index into the i = @index(Cartesian, $idx)
@@ -324,3 +303,37 @@ function emit(loop)
324
303
325
304
return unblock (Expr (:block , stmts... ))
326
305
end
306
+
307
+ function emit_gpu (loop)
308
+ stmts = Any[]
309
+ append! (stmts, loop. allocations)
310
+ for stmt in loop. private_allocations
311
+ if @capture (stmt, lhs_ = rhs_)
312
+ push! (stmts, :($ lhs = $ rhs))
313
+ else
314
+ error (" @private $stmt not an assignment" )
315
+ end
316
+ end
317
+
318
+ # don't emit empty loops
319
+ if ! (isempty (loop. stmts) || all (s -> s isa LineNumberNode, loop. stmts))
320
+ body = Expr (:block , loop. stmts... )
321
+ body = postwalk (body) do expr
322
+ if @capture (expr, lhs_ = rhs_)
323
+ if lhs in loop. private
324
+ error (" Can't assign to variables marked private" )
325
+ end
326
+ end
327
+ return expr
328
+ end
329
+ loopexpr = quote
330
+ $ (loop. indicies... )
331
+ if __active_lane__
332
+ $ (unblock (body))
333
+ end
334
+ end
335
+ push! (stmts, loopexpr)
336
+ end
337
+
338
+ return unblock (Expr (:block , stmts... ))
339
+ end
0 commit comments