@@ -116,7 +116,7 @@ function create_result(
116116end
117117
118118# Optimization passes via transform dialect
119- function optimization_passes (; no_nan:: Bool = false )
119+ function optimization_passes (; no_nan:: Bool = false , sroa :: Bool = false )
120120 transform_passes_list = [
121121 " patterns=compare_op_canon<16>" ,
122122 " transpose_transpose<16>" ,
@@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false)
295295 " ," ,
296296 )
297297 func_passes = join ([" canonicalize" , " cse" , " canonicalize" , transform_passes], " ," )
298- return join (
299- [
300- " inline{default-pipeline=canonicalize max-iterations=4}" ,
301- " libdevice-funcs-raise" ,
302- func_passes,
303- ],
298+ passes = [
299+ " inline{default-pipeline=canonicalize max-iterations=4}"
300+ ]
301+ if sroa
302+ push! (passes, " sroa-wrappers" )
303+ push! (passes, " libdevice-funcs-raise" )
304+ push! (passes, " canonicalize" )
305+ end
306+ push! (passes, func_passes)
307+ return join (passes,
304308 ' ,' ,
305309 )
306310end
310314const enzyme_pass:: String = " enzyme{postpasses=\" arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\" }"
311315
312316function run_pass_pipeline! (mod, pass_pipeline; enable_verifier= true )
317+ @show pass_pipeline
318+ flush (stdout )
313319 pm = MLIR. IR. PassManager ()
314320 MLIR. IR. enable_verifier! (pm, enable_verifier)
315321 opm = MLIR. IR. OpPassManager (pm)
@@ -374,9 +380,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
374380 kern = " lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) },symbol-dce"
375381
376382 opt_passes = optimization_passes (; no_nan)
383+ opt_passes2 = optimization_passes (; no_nan, sroa= false )
377384
378385 if optimize === :all
379- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
386+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
380387 run_pass_pipeline! (
381388 mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
382389 )
@@ -387,14 +394,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
387394 " canonicalize" ,
388395 " remove-unnecessary-enzyme-ops" ,
389396 " enzyme-simplify-math" ,
390- opt_passes ,
397+ opt_passes2 ,
391398 kern,
392399 ],
393400 ' ,' ,
394401 ),
395402 )
396403 elseif optimize === :before_kernel
397- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
404+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
398405 run_pass_pipeline! (
399406 mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
400407 )
@@ -405,13 +412,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
405412 " canonicalize" ,
406413 " remove-unnecessary-enzyme-ops" ,
407414 " enzyme-simplify-math" ,
408- opt_passes ,
415+ opt_passes2 ,
409416 ],
410417 ' ,' ,
411418 ),
412419 )
413420 elseif optimize === :no_enzyme
414- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
421+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
415422 run_pass_pipeline! (mod, " arith-raise{stablehlo=true}" ; enable_verifier= false )
416423 run_pass_pipeline! (
417424 mod,
@@ -420,7 +427,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
420427 " canonicalize" ,
421428 " remove-unnecessary-enzyme-ops" ,
422429 " enzyme-simplify-math" ,
423- opt_passes ,
430+ opt_passes2 ,
424431 ],
425432 ' ,' ,
426433 ),
@@ -449,14 +456,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
449456 " canonicalize" ,
450457 " remove-unnecessary-enzyme-ops" ,
451458 " enzyme-simplify-math" ,
452- opt_passes ,
459+ opt_passes2 ,
453460 kern,
454461 ],
455462 ' ,' ,
456463 ),
457464 )
458465 elseif optimize === :before_enzyme
459- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
466+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
460467 run_pass_pipeline! (
461468 mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
462469 )
0 commit comments