Skip to content

Commit c89b1c9

Browse files
committed
process instead of hoist anon funcs
1 parent 8a3c3eb commit c89b1c9

File tree

3 files changed

+136
-73
lines changed

3 files changed

+136
-73
lines changed

src/modeling/graphs.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,10 +1016,10 @@ function instruction!(ls::LoopSet, x::Expr)
10161016
@assert _x isa Expr
10171017
x = _x
10181018
end
1019-
if x.head :(->)
1020-
instr = last(x.args).value
1021-
instr keys(COST) && return Instruction(:LoopVectorization, instr)
1022-
end
1019+
# if x.head ≢ :(->)
1020+
instr = last(x.args).value
1021+
instr keys(COST) && return Instruction(:LoopVectorization, instr)
1022+
# end
10231023
instr = gensym!(ls, "f")
10241024
pushpreamble!(ls, Expr(:(=), instr, x))
10251025
Instruction(Symbol(""), instr)
@@ -1191,7 +1191,7 @@ function add_assignment!(ls::LoopSet, LHS, RHS, elementbytes::Int, position::Int
11911191
end
11921192
end
11931193

1194-
function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
1194+
function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int, mpref::Union{Nothing,ArrayReferenceMetaPosition} = nothing)
11951195
if ex.head === :call
11961196
finex = first(ex.args)::Symbol
11971197
if finex === :setindex!
@@ -1216,7 +1216,7 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
12161216
@assert localbody.head === :(=)
12171217
@assert length(localbody.args) == 2
12181218
LHS = (localbody.args[1])::Symbol
1219-
RHS = push!(ls, (localbody.args[2]), elementbytes, position)
1219+
RHS = push!(ls, (localbody.args[2]), elementbytes, position, mpref)
12201220
if isstore(RHS)
12211221
RHS
12221222
else

src/parse/add_compute.jl

Lines changed: 124 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -248,78 +248,140 @@ function add_reduction_update_parent!(
248248
opout
249249
end
250250

251-
251+
function substitute!(ex::Expr, d::Dict{Symbol,Symbol})
252+
for (i,arg) enumerate(ex.args)
253+
if arg isa Symbol
254+
ex.args[i] = get(d, arg, arg)
255+
elseif arg isa Expr
256+
substitute!(arg, d)
257+
end
258+
end
259+
end
260+
function argsymbol(ls::LoopSet, arg, mpref, elementbytes::Int, position::Int)::Symbol
261+
argsym = gensym!(ls, "anonarg")
262+
if mpref === nothing
263+
add_operation!(ls, argsym, arg, elementbytes, position)
264+
else
265+
add_operation!(ls, argsym, arg, mpref, elementbytes, position)
266+
end
267+
return argsym
268+
end
269+
function add_anon_func!(ls::LoopSet, LHS::Symbol, f::Expr, ex::Expr, position::Int, mpref::Union{Nothing,ArrayReferenceMetaPosition}, elementbytes::Int)::Operation
270+
d = Dict{Symbol,Symbol}()
271+
anonargs = f.args[1]
272+
if anonargs isa Symbol
273+
@assert length(ex.args) == 2
274+
arg = ex.args[2]
275+
if !(arg isa Symbol)
276+
argsym = argsymbol(ls, arg, mpref, elementbytes, position)
277+
else
278+
argsym = arg::Symbol
279+
end
280+
d[anonargs] = argsym
281+
else
282+
@assert Meta.isexpr(anonargs, :tuple)
283+
callargs = @view(ex.args[2:end])
284+
for i eachindex(anonargs.args, callargs)
285+
arg = callargs[i]
286+
if !(arg isa Symbol)
287+
argsym = argsymbol(ls, arg, mpref, elementbytes, position)
288+
else
289+
argsym = arg::Symbol
290+
end
291+
d[anonargs.args[i]] = argsym
292+
end
293+
end
294+
anonbody = f.args[2]
295+
substitute!(anonbody, d)
296+
for i 1:length(anonbody.args)-1
297+
exᵢ = anonbody.args[i]
298+
exᵢ isa Expr && push!(ls, exᵢ, elementbytes, position, mpref)
299+
end
300+
lastline = last(anonbody.args)
301+
retop = if lastline isa Symbol
302+
add_compute!(ls, LHS, instruction(:identity), Operation[getop(ls, lastline)], elementbytes)
303+
elseif Meta.isexpr(lastline, :call)
304+
add_compute!(ls, LHS, lastline, elementbytes, position, mpref)
305+
elseif Meta.isexpr(lastline, :(=))
306+
add_compute!(ls, LHS, lastline.args[2], elementbytes, position, mpref)
307+
else
308+
throw(LoopError("Last line of anon func not understood: $lastline"))
309+
end
310+
return retop
311+
end
252312
function add_compute!(
253313
ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int, position::Int,
254314
mpref::Union{Nothing,ArrayReferenceMetaPosition} = nothing
255-
)
256-
@assert ex.head === :call
315+
)::Operation
316+
@assert ex.head === :call
317+
fexpr = first(ex.args)
318+
Meta.isexpr(fexpr, :(->)) && return add_anon_func!(ls, var, fexpr, ex, position, mpref, elementbytes)
257319
# instr = instruction(first(ex.args))::Symbol
258-
instr = instruction!(ls, first(ex.args))::Instruction
259-
args = @view(ex.args[2:end])
260-
if (instr.instr === :pow_fast || instr.instr === :(^)) && length(args) == 2
261-
arg2 = args[2]
262-
arg2 isa Number && return add_pow!(ls, var, args[1], arg2, elementbytes, position)
263-
end
264-
vparents = Operation[]
265-
deps = Symbol[]
266-
reduceddeps = Symbol[]
267-
reduction_ind = 0
268-
# @show ex first(operations(ls)) === getop(ls, :kern_1_1, elementbytes) first(operations(ls)) getop(ls, :kern_1_1, elementbytes)
269-
for (ind,arg) enumerate(args)
270-
if var === arg
320+
instr = instruction!(ls, first(ex.args))::Instruction
321+
args = @view(ex.args[2:end])
322+
if (instr.instr === :pow_fast || instr.instr === :(^)) && length(args) == 2
323+
arg2 = args[2]
324+
arg2 isa Number && return add_pow!(ls, var, args[1], arg2, elementbytes, position)
325+
end
326+
vparents = Operation[]
327+
deps = Symbol[]
328+
reduceddeps = Symbol[]
329+
reduction_ind = 0
330+
# @show ex first(operations(ls)) === getop(ls, :kern_1_1, elementbytes) first(operations(ls)) getop(ls, :kern_1_1, elementbytes)
331+
for (ind,arg) enumerate(args)
332+
if var === arg
333+
reduction_ind = ind
334+
# add_reduction!(vparents, deps, reduceddeps, ls, arg, elementbytes)
335+
getop(ls, var, elementbytes)
336+
elseif arg isa Expr
337+
isref, argref = tryrefconvert(ls, arg, elementbytes, varname(mpref))
338+
if isref
339+
if mpref == argref
340+
if varname(mpref) === var
341+
id = findfirst(==(mpref.mref), ls.refs_aliasing_syms)
342+
mpref.varname = var = id === nothing ? var : ls.syms_aliasing_refs[id]
271343
reduction_ind = ind
272-
# add_reduction!(vparents, deps, reduceddeps, ls, arg, elementbytes)
273-
getop(ls, var, elementbytes)
274-
elseif arg isa Expr
275-
isref, argref = tryrefconvert(ls, arg, elementbytes, varname(mpref))
276-
if isref
277-
if mpref == argref
278-
if varname(mpref) === var
279-
id = findfirst(==(mpref.mref), ls.refs_aliasing_syms)
280-
mpref.varname = var = id === nothing ? var : ls.syms_aliasing_refs[id]
281-
reduction_ind = ind
282-
mergesetv!(deps, loopdependencies(add_load!(ls, argref, elementbytes)))
283-
else
284-
pushparent!(vparents, deps, reduceddeps, add_load!(ls, argref, elementbytes))
285-
end
286-
else
287-
argref.varname = gensym!(ls, "tempload")
288-
pushparent!(vparents, deps, reduceddeps, add_load!(ls, argref, elementbytes))
289-
end
290-
else
291-
add_parent!(vparents, deps, reduceddeps, ls, arg, elementbytes, position)
292-
end
293-
elseif arg ls.loopsymbols
294-
loopsymop = add_loopvalue!(ls, arg, elementbytes)
295-
pushparent!(vparents, deps, reduceddeps, loopsymop)
344+
mergesetv!(deps, loopdependencies(add_load!(ls, argref, elementbytes)))
345+
else
346+
pushparent!(vparents, deps, reduceddeps, add_load!(ls, argref, elementbytes))
347+
end
296348
else
297-
add_parent!(vparents, deps, reduceddeps, ls, arg, elementbytes, position)
349+
argref.varname = gensym!(ls, "tempload")
350+
pushparent!(vparents, deps, reduceddeps, add_load!(ls, argref, elementbytes))
298351
end
299-
end
300-
reduction = reduction_ind > 0
301-
loopnestview = view(ls.loopsymbols, 1:position)
302-
if iszero(length(deps)) && reduction
303-
append!(deps, loopnestview)
304-
append!(reduceddeps, loopnestview)
352+
else
353+
add_parent!(vparents, deps, reduceddeps, ls, arg, elementbytes, position)
354+
end
355+
elseif arg ls.loopsymbols
356+
loopsymop = add_loopvalue!(ls, arg, elementbytes)
357+
pushparent!(vparents, deps, reduceddeps, loopsymop)
305358
else
306-
newloopdeps = Symbol[]; newreduceddeps = Symbol[];
307-
setdiffv!(newloopdeps, newreduceddeps, deps, loopnestview)
308-
mergesetv!(newreduceddeps, reduceddeps)
309-
deps = newloopdeps; reduceddeps = newreduceddeps
359+
add_parent!(vparents, deps, reduceddeps, ls, arg, elementbytes, position)
310360
end
311-
# @show reduction, search_tree(vparents, var) ex var vparents mpref get(ls.opdict, var, nothing) search_tree_for_ref(ls, vparents, mpref, var) # relies on cycles being forbidden
312-
if reduction || search_tree(vparents, var)
313-
return add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
314-
else
315-
if mpref nothing && ((length(loopdependencies(mpref)) < position) | (length(reduceddependencies(mpref)) > 0))
316-
var, found = search_tree_for_ref(ls, vparents, mpref, var)
317-
found && return add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
318-
end
319-
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, vparents)
320-
return pushop!(ls, op, var)
361+
end
362+
reduction = reduction_ind > 0
363+
loopnestview = view(ls.loopsymbols, 1:position)
364+
if iszero(length(deps)) && reduction
365+
append!(deps, loopnestview)
366+
append!(reduceddeps, loopnestview)
367+
else
368+
newloopdeps = Symbol[]; newreduceddeps = Symbol[];
369+
setdiffv!(newloopdeps, newreduceddeps, deps, loopnestview)
370+
mergesetv!(newreduceddeps, reduceddeps)
371+
deps = newloopdeps; reduceddeps = newreduceddeps
372+
end
373+
# @show reduction, search_tree(vparents, var) ex var vparents mpref get(ls.opdict, var, nothing) search_tree_for_ref(ls, vparents, mpref, var) # relies on cycles being forbidden
374+
if reduction || search_tree(vparents, var)
375+
return add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
376+
else
377+
if mpref nothing && ((length(loopdependencies(mpref)) < position) | (length(reduceddependencies(mpref)) > 0))
378+
var, found = search_tree_for_ref(ls, vparents, mpref, var)
379+
found && return add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
321380
end
322-
# maybe_const_compute!(ls, op, elementbytes, position)
381+
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, vparents)
382+
return pushop!(ls, op, var)
383+
end
384+
# maybe_const_compute!(ls, op, elementbytes, position)
323385
end
324386
function add_reduction!(ls::LoopSet, var::Symbol, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
325387
parent = ls.opdict[var]

test/miscellaneous.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ using Test
310310
function softmax3_coreavx1!(lse, qq, xx, tmpmax, maxk, nk)
311311
for k in Base.OneTo(maxk)
312312
@turbo for i in eachindex(lse)
313-
tmp = exp(xx[i,k] - tmpmax[i])
313+
tmp = (tmpm -> exp(xx[i,k] - tmpm))(tmpmax[i])
314314
lse[i] += tmp
315315
qq[i,k] = tmp
316316
end
@@ -342,7 +342,7 @@ using Test
342342
function softmax3_coreavx2!(lse, qq, xx, tmpmax, maxk, nk)
343343
@turbo for k in Base.OneTo(maxk)
344344
for i in eachindex(lse)
345-
tmp = exp(xx[i,k] - tmpmax[i])
345+
tmp = (yy -> exp(yy[i,k] - tmpmax[i]))(xx)
346346
lse[i] += tmp
347347
qq[i,k] = tmp
348348
end
@@ -524,9 +524,10 @@ using Test
524524
end
525525
end
526526
function test_for_with_different_indexavx!(c, a, b, start_sample, num_samples)
527-
@turbo for i = start_sample:num_samples + start_sample - 1
528-
c[i] = b[i] * a[i]
529-
end
527+
@turbo for i = start_sample:num_samples + start_sample - 1
528+
aᵢ = a[i]
529+
c[i] = ((x,y) -> x*y)(b[i], aᵢ)
530+
end
530531
end
531532
function test_for_with_different_index_avx!(c, a, b, start_sample, num_samples)
532533
@_avx for i = start_sample:num_samples + start_sample - 1

0 commit comments

Comments
 (0)