@@ -116,7 +116,7 @@ function create_result(
116116end
117117
118118# Optimization passes via transform dialect
119- function optimization_passes (; no_nan:: Bool = false )
119+ function optimization_passes (; no_nan:: Bool = false , sroa :: Bool = false )
120120 transform_passes_list = [
121121 " patterns=compare_op_canon<16>" ,
122122 " transpose_transpose<16>" ,
@@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false)
295295 " ," ,
296296 )
297297 func_passes = join ([" canonicalize" , " cse" , " canonicalize" , transform_passes], " ," )
298- return join (
299- [
300- " inline{default-pipeline=canonicalize max-iterations=4}" ,
301- " libdevice-funcs-raise" ,
302- func_passes,
303- ],
298+ passes = [
299+ " inline{default-pipeline=canonicalize max-iterations=4}"
300+ ]
301+ if sroa
302+ push! (passes, " sroa-wrappers" )
303+ push! (passes, " libdevice-funcs-raise" )
304+ push! (passes, " canonicalize" )
305+ end
306+ push! (passes, func_passes)
307+ return join (passes,
304308 ' ,' ,
305309 )
306310end
351355const cuLaunch = Ref {UInt} (0 )
352356const cuFunc = Ref {UInt} (0 )
353357const cuModule = Ref {UInt} (0 )
358+ const cuSync = Ref {UInt} (0 )
359+ const DEBUG_KERNEL = Ref {Bool} (false )
354360
355361function compile_mlir! (mod, f, args; optimize:: Union{Bool,Symbol} = true , no_nan:: Bool = false )
356362 # Explicitly don't use block! to avoid creating a closure, which creates
@@ -379,12 +385,20 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
379385 if isdefined (Reactant_jll, :ptxas_path )
380386 toolkit = Reactant_jll. ptxas_path[1 : (end - length (" /bin/ptxas" ))]
381387 end
382- kern = " lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) },symbol-dce"
388+ if DEBUG_KERNEL[]
389+ curesulthandler = XLA. Libdl. dlsym (Reactant_jll. libReactantExtra_handle, " ReactantHandleCuResult" )
390+ @assert curesulthandler != = nothing
391+ curesulthandler = Base. reinterpret (UInt, curesulthandler)
392+ kern = " lower-kernel{debug=true cuResultHandlerPtr=$curesulthandler run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) cuStreamSynchronizePtr=$(cuSync[]) },symbol-dce"
393+ else
394+ kern = " lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) },symbol-dce"
395+ end
383396
384- opt_passes = optimization_passes (; no_nan)
397+ opt_passes = optimization_passes (; no_nan, sroa= true )
398+ opt_passes2 = optimization_passes (; no_nan, sroa= false )
385399
386400 if optimize === :all
387- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
401+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
388402 run_pass_pipeline! (
389403 mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
390404 )
@@ -395,14 +409,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
395409 " canonicalize" ,
396410 " remove-unnecessary-enzyme-ops" ,
397411 " enzyme-simplify-math" ,
398- opt_passes ,
412+ opt_passes2 ,
399413 kern,
400414 ],
401415 ' ,' ,
402416 ),
403417 )
404418 elseif optimize === :before_kernel
405- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
419+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
406420 run_pass_pipeline! (
407421 mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
408422 )
@@ -413,13 +427,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
413427 " canonicalize" ,
414428 " remove-unnecessary-enzyme-ops" ,
415429 " enzyme-simplify-math" ,
416- opt_passes ,
430+ opt_passes2 ,
417431 ],
418432 ' ,' ,
419433 ),
420434 )
421435 elseif optimize === :no_enzyme
422- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
436+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
423437 run_pass_pipeline! (mod, " arith-raise{stablehlo=true}" ; enable_verifier= false )
424438 run_pass_pipeline! (
425439 mod,
@@ -428,7 +442,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
428442 " canonicalize" ,
429443 " remove-unnecessary-enzyme-ops" ,
430444 " enzyme-simplify-math" ,
431- opt_passes ,
445+ opt_passes2 ,
432446 ],
433447 ' ,' ,
434448 ),
@@ -457,14 +471,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
457471 " canonicalize" ,
458472 " remove-unnecessary-enzyme-ops" ,
459473 " enzyme-simplify-math" ,
460- opt_passes ,
474+ opt_passes2 ,
461475 kern,
462476 ],
463477 ' ,' ,
464478 ),
465479 )
466480 elseif optimize === :before_enzyme
467- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
481+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
468482 run_pass_pipeline! (
469483 mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
470484 )
0 commit comments