Skip to content

Commit 01d498a

Browse files
committed
reuse macro infrastructure
1 parent ce95488 commit 01d498a

File tree

1 file changed

+39
-26
lines changed

1 file changed

+39
-26
lines changed

src/macros.jl

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -94,29 +94,7 @@ function transform_gpu!(def, constargs, force_inbounds)
9494
if force_inbounds
9595
push!(new_stmts, Expr(:inbounds, true))
9696
end
97-
98-
# fix convergence
99-
active_stmts = Any[]
100-
for stmt in stmts
101-
has_sync = find_sync(stmt)
102-
if has_sync
103-
push!(new_stmts, Expr(:if, :__active_lane__, Expr(:block, active_stmts...)))
104-
empty!(active_stmts)
105-
push!(new_stmts, stmt)
106-
continue
107-
end
108-
if @capture(stmt, @uniform x_)
109-
push!(new_stmts, stmt)
110-
continue
111-
elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
112-
if @capture(rhs, @localmem(args__) | @uniform(args__))
113-
push!(new_stmts, stmt)
114-
continue
115-
end
116-
end
117-
push!(active_stmts, stmt)
118-
end
119-
push!(new_stmts, Expr(:if, :__active_lane__, Expr(:block, active_stmts...)))
97+
append!(new_stmts, split(emit_gpu, body.args))
12098
if force_inbounds
12199
push!(new_stmts, Expr(:inbounds, :pop))
122100
end
@@ -151,7 +129,7 @@ function transform_cpu!(def, constargs, force_inbounds)
151129
if force_inbounds
152130
push!(new_stmts, Expr(:inbounds, true))
153131
end
154-
append!(new_stmts, split(body.args))
132+
append!(new_stmts, split(emit_cpu, body.args))
155133
if force_inbounds
156134
push!(new_stmts, Expr(:inbounds, :pop))
157135
end
@@ -191,6 +169,7 @@ end
191169

192170
# TODO proper handling of LineInfo
193171
function split(
172+
emit,
194173
stmts,
195174
indicies = Any[], private = Set{Symbol}(),
196175
)
@@ -221,7 +200,7 @@ function split(
221200
function recurse(expr::Expr)
222201
expr = unblock(expr)
223202
if is_scope_construct(expr) && any(find_sync, expr.args)
224-
new_args = unblock(split(expr.args, deepcopy(indicies), deepcopy(private)))
203+
new_args = unblock(split(emit, expr.args, deepcopy(indicies), deepcopy(private)))
225204
return Expr(expr.head, new_args...)
226205
else
227206
return Expr(expr.head, map(recurse, expr.args)...)
@@ -270,7 +249,7 @@ function split(
270249
return new_stmts
271250
end
272251

273-
function emit(loop)
252+
function emit_cpu(loop)
274253
idx = gensym(:I)
275254
for stmt in loop.indicies
276255
# splice index into the i = @index(Cartesian, $idx)
@@ -324,3 +303,37 @@ function emit(loop)
324303

325304
return unblock(Expr(:block, stmts...))
326305
end
306+
307+
"""
    emit_gpu(loop)

Assemble the GPU-side expression for a single `loop` record.

The returned expression contains, in order:

1. every entry of `loop.allocations`, unchanged;
2. each entry of `loop.private_allocations`, re-emitted as a plain
   `lhs = rhs` assignment;
3. if `loop.stmts` holds anything other than `LineNumberNode`s, the statements
   of `loop.indicies` followed by the loop statements wrapped in
   `if __active_lane__ ... end`.

Raises an error when a private allocation is not an assignment, or when the
loop statements assign to a name contained in `loop.private`.
"""
function emit_gpu(loop)
    out = Any[]
    append!(out, loop.allocations)

    for alloc in loop.private_allocations
        # Guard clause: anything that is not `lhs = rhs` is rejected outright.
        @capture(alloc, lhs_ = rhs_) || error("@private $alloc not an assignment")
        push!(out, :($lhs = $rhs))
    end

    # Skip emission entirely unless at least one real (non line-number)
    # statement is present; an empty statement list therefore emits nothing.
    if any(s -> !(s isa LineNumberNode), loop.stmts)
        # Visits every sub-expression and rejects assignments whose left-hand
        # side is listed in `loop.private`; expressions pass through unchanged.
        reject_private_assignment = function (ex)
            if @capture(ex, lhs_ = rhs_) && lhs in loop.private
                error("Can't assign to variables marked private")
            end
            return ex
        end
        checked = postwalk(reject_private_assignment, Expr(:block, loop.stmts...))

        # Index statements are emitted ahead of the `__active_lane__` guard;
        # only the loop statements themselves sit inside it.
        guarded = quote
            $(loop.indicies...)
            if __active_lane__
                $(unblock(checked))
            end
        end
        push!(out, guarded)
    end

    return unblock(Expr(:block, out...))
end

0 commit comments

Comments
 (0)