@@ -248,78 +248,140 @@ function add_reduction_update_parent!(
248
248
opout
249
249
end
250
250
251
-
251
"""
    substitute!(ex::Expr, d::Dict{Symbol,Symbol}) -> ex

Rename symbols inside `ex` in place.

Every element of `ex.args` that is a `Symbol` and a key of `d` is replaced by `d[key]`;
every element that is an `Expr` is processed recursively. Anything else (literals,
`QuoteNode`s, `LineNumberNode`s, ...) is left untouched, and `ex.head` is never changed.
Each symbol is looked up exactly once, so a mapping such as `a => b, b => a` swaps the
two names rather than collapsing them.

Returns the mutated `ex`, so the call can be chained or used as an expression.
"""
function substitute!(ex::Expr, d::Dict{Symbol,Symbol})
    args = ex.args
    for i ∈ eachindex(args)
        arg = args[i]
        if arg isa Symbol
            # `get` falls back to the symbol itself when there is no mapping for it.
            args[i] = get(d, arg, arg)
        elseif arg isa Expr
            substitute!(arg, d)
        end
    end
    return ex
end
260
"""
    argsymbol(ls::LoopSet, arg, mpref, elementbytes::Int, position::Int) -> Symbol

Create a fresh symbol with `gensym!(ls, ...)`, register `arg` with `ls` under that name
through `add_operation!`, and return the symbol.

`mpref` is forwarded to `add_operation!` only when it is not `nothing`; otherwise the
`add_operation!` method without that argument is called.
"""
function argsymbol(ls::LoopSet, arg, mpref, elementbytes::Int, position::Int)::Symbol
    name = gensym!(ls, " anonarg")
    if mpref !== nothing
        add_operation!(ls, name, arg, mpref, elementbytes, position)
        return name
    end
    add_operation!(ls, name, arg, elementbytes, position)
    return name
end
269
"""
    add_anon_func!(ls, LHS, f, ex, position, mpref, elementbytes) -> Operation

Add an immediately-called anonymous function to `ls` by inlining its body.

`ex` is the call expression and `f` its callee, an `->` expression whose first argument
is either a single parameter `Symbol` or a `:tuple` of parameters and whose second
argument is the body.

Steps:
1. Each parameter is mapped to the symbol of the matching call argument. An argument
   that already is a `Symbol` is used as is; any other argument is first registered with
   `ls` under a fresh symbol via `argsymbol`.
2. The parameters are renamed inside a *copy* of the body with `substitute!`, so the
   caller's `f`/`ex` are left unmodified.
3. Every body entry except the last that is an `Expr` is pushed onto `ls` with
   `push!(ls, exᵢ, elementbytes, position, mpref)`.
4. The last entry defines `LHS`: a `Symbol` becomes an `identity` compute on the
   operation returned by `getop`, a `:call` is handed to `add_compute!`, and for an
   assignment its right-hand side is handed to `add_compute!`.

# Throws
- `AssertionError` if a single-parameter function is not called with exactly one
  argument, or if the parameter list is neither a `Symbol` nor a `:tuple`.
- `LoopError` if the last body entry is none of the supported forms.
"""
function add_anon_func!(
    ls::LoopSet, LHS::Symbol, f::Expr, ex::Expr, position::Int,
    mpref::Union{Nothing,ArrayReferenceMetaPosition}, elementbytes::Int
)::Operation
    # Symbol bound to one call argument (see step 1 above).
    bindarg(arg) = arg isa Symbol ? arg : argsymbol(ls, arg, mpref, elementbytes, position)

    d = Dict{Symbol,Symbol}()
    anonargs = f.args[1]
    if anonargs isa Symbol
        @assert length(ex.args) == 2
        d[anonargs] = bindarg(ex.args[2])
    else
        @assert Meta.isexpr(anonargs, :tuple)
        callargs = @view(ex.args[2:end])
        # Two-argument `eachindex` also checks that parameters and arguments line up.
        for i ∈ eachindex(anonargs.args, callargs)
            d[anonargs.args[i]] = bindarg(callargs[i])
        end
    end

    # `substitute!` rewrites in place; operate on a copy (`copy(::Expr)` also copies
    # nested `Expr`s) so the expression owned by the caller keeps parameter names that
    # agree with its body.
    anonbody = copy(f.args[2])
    substitute!(anonbody, d)

    for i ∈ 1:length(anonbody.args)-1
        exᵢ = anonbody.args[i]
        # Non-`Expr` entries (e.g. `LineNumberNode`s) are skipped.
        exᵢ isa Expr && push!(ls, exᵢ, elementbytes, position, mpref)
    end

    lastline = last(anonbody.args)
    if lastline isa Symbol
        parents = Operation[getop(ls, lastline)]
        return add_compute!(ls, LHS, instruction(:identity), parents, elementbytes)
    elseif Meta.isexpr(lastline, :call)
        return add_compute!(ls, LHS, lastline, elementbytes, position, mpref)
    elseif Meta.isexpr(lastline, :(=))
        return add_compute!(ls, LHS, lastline.args[2], elementbytes, position, mpref)
    end
    throw(LoopError(" Last line of anon func not understood: $lastline "))
end
252
312
"""
    add_compute!(ls, var, ex, elementbytes, position, mpref = nothing) -> Operation

Add the call expression `ex` to `ls` as the compute operation that defines `var`.

Dispatch order:
- A callee that is an `->` expression is forwarded to `add_anon_func!`.
- A two-argument `^`/`pow_fast` whose exponent is a `Number` is forwarded to `add_pow!`.
- Otherwise each call argument is turned into a parent operation (or marks the call as
  a reduction), the loop dependencies are assembled, and either `add_reduction!` is
  called or a new `compute` `Operation` is created and registered with `pushop!`.

An argument marks the call as a reduction when it is `var` itself, or when it is an
array reference equal to `mpref` whose name is `var`.

# Throws
- `AssertionError` if `ex` is not a `:call` expression.
"""
function add_compute!(
    ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int, position::Int,
    mpref::Union{Nothing,ArrayReferenceMetaPosition} = nothing
)::Operation
    @assert ex.head === :call
    fexpr = first(ex.args)
    if Meta.isexpr(fexpr, :(->))
        return add_anon_func!(ls, var, fexpr, ex, position, mpref, elementbytes)
    end

    instr = instruction!(ls, fexpr)::Instruction
    args = @view(ex.args[2:end])
    ispow = instr.instr === :pow_fast || instr.instr === :(^)
    if ispow && length(args) == 2
        exponent = args[2]
        if exponent isa Number
            return add_pow!(ls, var, args[1], exponent, elementbytes, position)
        end
    end

    vparents = Operation[]
    deps = Symbol[]
    reduceddeps = Symbol[]
    reduction_ind = 0  # position of the self-referencing argument; 0 means none found
    for (ind, arg) ∈ enumerate(args)
        if var === arg
            reduction_ind = ind
            # Return value unused; presumably called for its effect on `ls` — TODO confirm.
            getop(ls, var, elementbytes)
        elseif arg isa Expr
            isref, argref = tryrefconvert(ls, arg, elementbytes, varname(mpref))
            if !isref
                add_parent!(vparents, deps, reduceddeps, ls, arg, elementbytes, position)
            elseif !(mpref == argref)
                # A reference other than `mpref`: load it under a fresh name.
                argref.varname = gensym!(ls, " tempload")
                pushparent!(vparents, deps, reduceddeps, add_load!(ls, argref, elementbytes))
            elseif varname(mpref) === var
                # Same reference and same name as the target: treat as a reduction.
                # If the reference is listed in `refs_aliasing_syms`, continue under the
                # symbol stored at the same index of `syms_aliasing_refs`.
                id = findfirst(==(mpref.mref), ls.refs_aliasing_syms)
                if id !== nothing
                    var = ls.syms_aliasing_refs[id]
                end
                mpref.varname = var
                reduction_ind = ind
                mergesetv!(deps, loopdependencies(add_load!(ls, argref, elementbytes)))
            else
                pushparent!(vparents, deps, reduceddeps, add_load!(ls, argref, elementbytes))
            end
        elseif arg ∈ ls.loopsymbols
            pushparent!(vparents, deps, reduceddeps, add_loopvalue!(ls, arg, elementbytes))
        else
            add_parent!(vparents, deps, reduceddeps, ls, arg, elementbytes, position)
        end
    end

    reduction = reduction_ind > 0
    loopnestview = view(ls.loopsymbols, 1:position)
    if reduction && isempty(deps)
        # A reduction with no dependencies collected so far takes the first `position`
        # loop symbols as both its dependencies and its reduced dependencies.
        append!(deps, loopnestview)
        append!(reduceddeps, loopnestview)
    else
        newloopdeps = Symbol[]
        newreduceddeps = Symbol[]
        setdiffv!(newloopdeps, newreduceddeps, deps, loopnestview)
        mergesetv!(newreduceddeps, reduceddeps)
        deps = newloopdeps
        reduceddeps = newreduceddeps
    end

    if reduction || search_tree(vparents, var)
        return add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
    end
    if mpref !== nothing && ((length(loopdependencies(mpref)) < position) | (length(reduceddependencies(mpref)) > 0))
        var, found = search_tree_for_ref(ls, vparents, mpref, var)
        found && return add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
    end
    op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, vparents)
    return pushop!(ls, op, var)
end
324
386
function add_reduction! (ls:: LoopSet , var:: Symbol , reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
325
387
parent = ls. opdict[var]
0 commit comments