@@ -40,24 +40,36 @@ static void markAsGPUContainer(ModuleOp topModule) {
4040 UnitAttr::get (topModule->getContext ()));
4141}
4242
43- // / Constructs a new GPU module (for GPU kernels) inside the given top module.
44- static gpu::GPUModuleOp genGPUModule (OpBuilder &builder, ModuleOp topModule,
45- StringRef name) {
43+ // / Constructs a new GPU module (for GPU kernels) inside the given top module,
44+ // / or returns an existing GPU module if one was built previously.
45+ static gpu::GPUModuleOp genGPUModule (OpBuilder &builder, ModuleOp topModule) {
46+ for (auto op : topModule.getBodyRegion ().getOps <gpu::GPUModuleOp>())
47+ return op; // existing
4648 markAsGPUContainer (topModule);
4749 builder.setInsertionPointToStart (&topModule.getBodyRegion ().front ());
48- return builder.create <gpu::GPUModuleOp>(topModule->getLoc (), name);
50+ return builder.create <gpu::GPUModuleOp>(topModule->getLoc (),
51+ " sparse_kernels" );
4952}
5053
5154// / Constructs a new GPU kernel in the given GPU module.
5255static gpu::GPUFuncOp genGPUFunc (OpBuilder &builder, gpu::GPUModuleOp gpuModule,
53- StringRef name, SmallVectorImpl<Value> &args) {
56+ SmallVectorImpl<Value> &args) {
57+ // Get a unique kernel name. Not very creative,
58+ // but we simply try kernel0, kernel1, etc.
59+ unsigned kernelNumber = 0 ;
60+ SmallString<16 > kernelName;
61+ do {
62+ kernelName.clear ();
63+ (" kernel" + Twine (kernelNumber++)).toStringRef (kernelName);
64+ } while (gpuModule.lookupSymbol (kernelName));
65+ // Then we insert a new kernel with given arguments into the module.
5466 builder.setInsertionPointToStart (&gpuModule.getBodyRegion ().front ());
5567 SmallVector<Type> argsTp;
5668 for (unsigned i = 0 , e = args.size (); i < e; i++)
5769 argsTp.push_back (args[i].getType ());
5870 FunctionType type = FunctionType::get (gpuModule->getContext (), argsTp, {});
5971 auto gpuFunc =
60- builder.create <gpu::GPUFuncOp>(gpuModule->getLoc (), name , type);
72+ builder.create <gpu::GPUFuncOp>(gpuModule->getLoc (), kernelName , type);
6173 gpuFunc->setAttr (gpu::GPUDialect::getKernelFuncAttrName (),
6274 builder.getUnitAttr ());
6375 return gpuFunc;
@@ -208,12 +220,9 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
208220 args.push_back (genHostRegisterMemref (rewriter, loc, b));
209221 auto saveIp = rewriter.saveInsertionPoint ();
210222 // Set up GPU module and construct GPU function.
211- //
212- // TODO: only generate once, avoid name conflict
213- //
214223 ModuleOp topModule = forallOp->getParentOfType <ModuleOp>();
215- auto gpuModule = genGPUModule (rewriter, topModule, " sparsekernels " );
216- auto gpuFunc = genGPUFunc (rewriter, gpuModule, " kernel " , args);
224+ auto gpuModule = genGPUModule (rewriter, topModule);
225+ auto gpuFunc = genGPUFunc (rewriter, gpuModule, args);
217226 genGPUCode (rewriter, gpuFunc, forallOp, constants, scalars, buffers);
218227 // Generate code that launches the kernel.
219228 rewriter.restoreInsertionPoint (saveIp);
0 commit comments