@@ -642,11 +642,10 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
642642
643643// / Generates a symbol with 0-sized array type for dynamic shared memory usage,
644644// / or uses existing symbol.
645- LLVM::GlobalOp
646- getDynamicSharedMemorySymbol (ConversionPatternRewriter &rewriter,
647- Operation *moduleOp, gpu::DynamicSharedMemoryOp op,
648- const LLVMTypeConverter *typeConverter,
649- MemRefType memrefType, unsigned alignmentBit) {
645+ LLVM::GlobalOp getDynamicSharedMemorySymbol (
646+ ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp,
647+ gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
648+ MemRefType memrefType, unsigned alignmentBit) {
650649 uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth ();
651650
652651 FailureOr<unsigned > addressSpace =
@@ -661,8 +660,7 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
661660 // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
662661 // LLVM::GlobalOp is suitable for shared memory, return it.
663662 llvm::StringSet<> existingGlobalNames;
664- for (auto globalOp :
665- moduleOp->getRegion (0 ).front ().getOps <LLVM::GlobalOp>()) {
663+ for (auto globalOp : moduleOp.getBody ()->getOps <LLVM::GlobalOp>()) {
666664 existingGlobalNames.insert (globalOp.getSymName ());
667665 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType ())) {
668666 if (globalOp.getAddrSpace () == addressSpace.value () &&
@@ -684,7 +682,7 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
684682
685683 // Step 3. Generate a global op
686684 OpBuilder::InsertionGuard guard (rewriter);
687- rewriter.setInsertionPoint (& moduleOp-> getRegion ( 0 ). front (). front ());
685+ rewriter.setInsertionPointToStart ( moduleOp. getBody ());
688686
689687 auto zeroSizedArrayType = LLVM::LLVMArrayType::get (
690688 typeConverter->convertType (memrefType.getElementType ()), 0 );
@@ -709,10 +707,8 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
709707
710708 // Step 2: Generate a global symbol or existing for the dynamic shared
711709 // memory with memref<0xi8> type
712- LLVM::LLVMFuncOp funcOp = op->getParentOfType <LLVM::LLVMFuncOp>();
713- LLVM::GlobalOp shmemOp = {};
714- Operation *moduleOp = funcOp->getParentWithTrait <OpTrait::SymbolTable>();
715- shmemOp = getDynamicSharedMemorySymbol (
710+ auto moduleOp = op->getParentOfType <gpu::GPUModuleOp>();
711+ LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol (
716712 rewriter, moduleOp, op, getTypeConverter (), memrefType0sz, alignmentBit);
717713
718714 // Step 3. Get address of the global symbol
0 commit comments