2020
2121using namespace mlir ;
2222
23- LLVM::LLVMFuncOp mlir::getOrDefineFunction (gpu::GPUModuleOp moduleOp,
24- Location loc, OpBuilder &b,
25- StringRef name,
23+ LLVM::LLVMFuncOp mlir::getOrDefineFunction (Operation *moduleOp, Location loc,
24+ OpBuilder &b, StringRef name,
2625 LLVM::LLVMFunctionType type) {
27- LLVM::LLVMFuncOp ret;
28- if (!(ret = moduleOp.template lookupSymbol <LLVM::LLVMFuncOp>(name))) {
29- OpBuilder::InsertionGuard guard (b);
30- b.setInsertionPointToStart (moduleOp.getBody ());
31- ret = LLVM::LLVMFuncOp::create (b, loc, name, type, LLVM::Linkage::External);
32- }
33- return ret;
26+ auto existing = dyn_cast_or_null<LLVM::LLVMFuncOp>(
27+ SymbolTable::lookupSymbolIn (moduleOp, name));
28+ if (existing)
29+ return existing;
30+
31+ OpBuilder::InsertionGuard guard (b);
32+ b.setInsertionPointToStart (&moduleOp->getRegion (0 ).front ());
33+ return LLVM::LLVMFuncOp::create (b, loc, name, type, LLVM::Linkage::External);
3434}
3535
36- static SmallString<16 > getUniqueSymbolName (gpu::GPUModuleOp moduleOp,
36+ static SmallString<16 > getUniqueSymbolName (Operation * moduleOp,
3737 StringRef prefix) {
3838 // Get a unique global name.
3939 unsigned stringNumber = 0 ;
4040 SmallString<16 > stringConstName;
4141 do {
4242 stringConstName.clear ();
4343 (prefix + Twine (stringNumber++)).toStringRef (stringConstName);
44- } while (moduleOp. lookupSymbol ( stringConstName));
44+ } while (SymbolTable::lookupSymbolIn (moduleOp, stringConstName));
4545 return stringConstName;
4646}
4747
48- LLVM::GlobalOp
49- mlir::getOrCreateStringConstant (OpBuilder &b, Location loc,
50- gpu::GPUModuleOp moduleOp, Type llvmI8,
51- StringRef namePrefix, StringRef str,
52- uint64_t alignment, unsigned addrSpace) {
48+ LLVM::GlobalOp mlir::getOrCreateStringConstant (OpBuilder &b, Location loc,
49+ Operation *moduleOp, Type llvmI8,
50+ StringRef namePrefix,
51+ StringRef str,
52+ uint64_t alignment,
53+ unsigned addrSpace) {
5354 llvm::SmallString<20 > nullTermStr (str);
5455 nullTermStr.push_back (' \0 ' ); // Null terminate for C
5556 auto globalType =
5657 LLVM::LLVMArrayType::get (llvmI8, nullTermStr.size_in_bytes ());
5758 StringAttr attr = b.getStringAttr (nullTermStr);
5859
5960 // Try to find existing global.
60- for (auto globalOp : moduleOp.getOps <LLVM::GlobalOp>())
61+ for (auto globalOp : moduleOp-> getRegion ( 0 ) .getOps <LLVM::GlobalOp>())
6162 if (globalOp.getGlobalType () == globalType && globalOp.getConstant () &&
6263 globalOp.getValueAttr () == attr &&
6364 globalOp.getAlignment ().value_or (0 ) == alignment &&
@@ -66,7 +67,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
6667
6768 // Not found: create new global.
6869 OpBuilder::InsertionGuard guard (b);
69- b.setInsertionPointToStart (moduleOp. getBody ());
70+ b.setInsertionPointToStart (& moduleOp-> getRegion ( 0 ). front ());
7071 SmallString<16 > name = getUniqueSymbolName (moduleOp, namePrefix);
7172 return LLVM::GlobalOp::create (b, loc, globalType,
7273 /* isConstant=*/ true , LLVM::Linkage::Internal,
@@ -396,10 +397,11 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
396397 auto ptrType = LLVM::LLVMPointerType::get (rewriter.getContext ());
397398 mlir::Type llvmI32 = typeConverter->convertType (rewriter.getI32Type ());
398399 mlir::Type llvmI64 = typeConverter->convertType (rewriter.getI64Type ());
399- // Note: this is the GPUModule op, not the ModuleOp that surrounds it
400- // This ensures that global constants and declarations are placed within
401- // the device code, not the host code
402- auto moduleOp = gpuPrintfOp->getParentOfType <gpu::GPUModuleOp>();
400+
401+ Operation *moduleOp = gpuPrintfOp->getParentWithTrait <OpTrait::SymbolTable>();
402+ if (!moduleOp)
403+ return rewriter.notifyMatchFailure (gpuPrintfOp,
404+ " Couldn't find a parent module" );
403405
404406 auto ocklBegin =
405407 getOrDefineFunction (moduleOp, loc, rewriter, " __ockl_printf_begin" ,
@@ -496,10 +498,10 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
496498 mlir::Type ptrType =
497499 LLVM::LLVMPointerType::get (rewriter.getContext (), addressSpace);
498500
499- // Note: this is the GPUModule op, not the ModuleOp that surrounds it
500- // This ensures that global constants and declarations are placed within
501- // the device code, not the host code
502- auto moduleOp = gpuPrintfOp-> getParentOfType <gpu::GPUModuleOp>( );
501+ Operation *moduleOp = gpuPrintfOp-> getParentWithTrait <OpTrait::SymbolTable>();
502+ if (!moduleOp)
503+ return rewriter. notifyMatchFailure (gpuPrintfOp,
504+ " Couldn't find a parent module " );
503505
504506 auto printfType =
505507 LLVM::LLVMFunctionType::get (rewriter.getI32Type (), {ptrType},
@@ -541,10 +543,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
541543 mlir::Type llvmI8 = typeConverter->convertType (rewriter.getIntegerType (8 ));
542544 mlir::Type ptrType = LLVM::LLVMPointerType::get (rewriter.getContext ());
543545
544- // Note: this is the GPUModule op, not the ModuleOp that surrounds it
545- // This ensures that global constants and declarations are placed within
546- // the device code, not the host code
547- auto moduleOp = gpuPrintfOp-> getParentOfType <gpu::GPUModuleOp>( );
546+ Operation *moduleOp = gpuPrintfOp-> getParentWithTrait <OpTrait::SymbolTable>();
547+ if (!moduleOp)
548+ return rewriter. notifyMatchFailure (gpuPrintfOp,
549+ " Couldn't find a parent module " );
548550
549551 // Create a valid global location removing any metadata attached to the
550552 // location as debug info metadata inside of a function cannot be used outside
0 commit comments