@@ -365,14 +365,14 @@ class GpuKernelOutliningPass
365365 auto funcWalkResult = func.walk ([&](gpu::LaunchOp op) {
366366 SetVector<Value> operands;
367367 std::string kernelFnName;
368- if (auto outlineModuleAttr = op-> getAttrOfType <SymbolRefAttr>( " outline_module " )) {
369- kernelFnName = outlineModuleAttr .getRootReference ().str ();
370- llvm::errs () << " outlined module name = " << kernelFnName << " \n " ;
368+ if (op. hasKernelFuncName ( )) {
369+ kernelFnName = op-> getAttrOfType <mlir::SymbolRefAttr>( " kernelFunc " ) .getRootReference ().str ();
370+ llvm::errs () << " use provided kernel func name = " << kernelFnName << " \n " ;
371371 } else {
372372 kernelFnName =
373373 Twine (op->getParentOfType <SymbolOpInterface>().getName (), " _kernel" )
374374 .str ();
375- llvm::errs () << " original module name = " << kernelFnName << " \n " ;
375+ llvm::errs () << " use default kernel func name = " << kernelFnName << " \n " ;
376376 }
377377
378378 gpu::GPUFuncOp outlinedFunc =
@@ -381,7 +381,7 @@ class GpuKernelOutliningPass
381381 // Create nested module and insert outlinedFunc. The module will
382382 // originally get the same name as the function, but may be renamed on
383383 // insertion into the parent module.
384- auto kernelModule = createKernelModule (outlinedFunc, symbolTable);
384+ auto kernelModule = createKernelModule (op, outlinedFunc, symbolTable);
385385 symbolTable.insert (kernelModule, insertPt);
386386
387387 // Potentially changes signature, pulling in constants.
@@ -402,16 +402,34 @@ class GpuKernelOutliningPass
402402
403403private:
404404 // / Returns a gpu.module containing kernelFunc and all callees (recursive).
405- gpu::GPUModuleOp createKernelModule (gpu::GPUFuncOp kernelFunc,
405+ gpu::GPUModuleOp createKernelModule (gpu::LaunchOp gpuLaunchOp, gpu:: GPUFuncOp kernelFunc,
406406 const SymbolTable &parentSymbolTable) {
407407 // TODO: This code cannot use an OpBuilder because it must be inserted into
408408 // a SymbolTable by the caller. SymbolTable needs to be refactored to
409409 // prevent manual building of Ops with symbols in code using SymbolTables
410410 // and then this needs to use the OpBuilder.
411411 auto *context = getOperation ().getContext ();
412412 OpBuilder builder (context);
413- auto kernelModule = builder.create <gpu::GPUModuleOp>(kernelFunc.getLoc (),
414- kernelFunc.getName ());
413+ std::string kernelModuleName;
414+ if (gpuLaunchOp.hasKernelModuleName ()) {
415+ kernelModuleName = gpuLaunchOp->getAttrOfType <mlir::SymbolRefAttr>(" kernelModule" ).getRootReference ().str ();
416+ llvm::errs () << " use provided kernel module name = " << kernelModuleName << " \n " ;
417+ } else {
418+ kernelModuleName = kernelFunc.getName ();
419+ llvm::errs () << " use default kernel module name = " << kernelModuleName << " \n " ;
420+ }
421+
422+ gpu::GPUModuleOp kernelModule;
423+ // Check if the module already exists in the symbol table
424+ if (auto existingModule = parentSymbolTable.lookup <gpu::GPUModuleOp>(kernelModuleName)) {
425+ llvm::errs () << " Reusing existing kernel module: " << kernelModuleName << " \n " ;
426+ kernelModule = existingModule;
427+ } else {
428+ // If not found, create a new GPU module
429+ llvm::errs () << " Creating new kernel module: " << kernelModuleName << " \n " ;
430+ kernelModule = builder.create <gpu::GPUModuleOp>(kernelFunc.getLoc (),
431+ kernelModuleName);
432+ }
415433
416434 // If a valid data layout spec was provided, attach it to the kernel module.
417435 // Otherwise, the default data layout will be used.
@@ -439,6 +457,8 @@ class GpuKernelOutliningPass
439457 }
440458 }
441459
460+ // llvm::errs() << "kernelModule:\n" << kernelModule << "\n";
461+
442462 return kernelModule;
443463 }
444464
0 commit comments