Skip to content

Commit d2e3389

Browse files
authored
[mlir][GPU] Generalize gpu.printf to not need gpu.module (#161266)
In order to make the gpu.printf => [various LLVM calls] passes less order-dependent and to allow downstreams that don't use gpu.module to use gpu.printf, allow the lowerings for such prints to target the nearest `SymbolTable` instead.
1 parent dc6e4e9 commit d2e3389

File tree

3 files changed

+74
-33
lines changed

3 files changed

+74
-33
lines changed

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,44 +20,45 @@
2020

2121
using namespace mlir;
2222

23-
LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp,
24-
Location loc, OpBuilder &b,
25-
StringRef name,
23+
LLVM::LLVMFuncOp mlir::getOrDefineFunction(Operation *moduleOp, Location loc,
24+
OpBuilder &b, StringRef name,
2625
LLVM::LLVMFunctionType type) {
27-
LLVM::LLVMFuncOp ret;
28-
if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
29-
OpBuilder::InsertionGuard guard(b);
30-
b.setInsertionPointToStart(moduleOp.getBody());
31-
ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
32-
}
33-
return ret;
26+
auto existing = dyn_cast_or_null<LLVM::LLVMFuncOp>(
27+
SymbolTable::lookupSymbolIn(moduleOp, name));
28+
if (existing)
29+
return existing;
30+
31+
OpBuilder::InsertionGuard guard(b);
32+
b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
33+
return LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
3434
}
3535

36-
static SmallString<16> getUniqueSymbolName(gpu::GPUModuleOp moduleOp,
36+
static SmallString<16> getUniqueSymbolName(Operation *moduleOp,
3737
StringRef prefix) {
3838
// Get a unique global name.
3939
unsigned stringNumber = 0;
4040
SmallString<16> stringConstName;
4141
do {
4242
stringConstName.clear();
4343
(prefix + Twine(stringNumber++)).toStringRef(stringConstName);
44-
} while (moduleOp.lookupSymbol(stringConstName));
44+
} while (SymbolTable::lookupSymbolIn(moduleOp, stringConstName));
4545
return stringConstName;
4646
}
4747

48-
LLVM::GlobalOp
49-
mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
50-
gpu::GPUModuleOp moduleOp, Type llvmI8,
51-
StringRef namePrefix, StringRef str,
52-
uint64_t alignment, unsigned addrSpace) {
48+
LLVM::GlobalOp mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
49+
Operation *moduleOp, Type llvmI8,
50+
StringRef namePrefix,
51+
StringRef str,
52+
uint64_t alignment,
53+
unsigned addrSpace) {
5354
llvm::SmallString<20> nullTermStr(str);
5455
nullTermStr.push_back('\0'); // Null terminate for C
5556
auto globalType =
5657
LLVM::LLVMArrayType::get(llvmI8, nullTermStr.size_in_bytes());
5758
StringAttr attr = b.getStringAttr(nullTermStr);
5859

5960
// Try to find existing global.
60-
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
61+
for (auto globalOp : moduleOp->getRegion(0).getOps<LLVM::GlobalOp>())
6162
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
6263
globalOp.getValueAttr() == attr &&
6364
globalOp.getAlignment().value_or(0) == alignment &&
@@ -66,7 +67,7 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
6667

6768
// Not found: create new global.
6869
OpBuilder::InsertionGuard guard(b);
69-
b.setInsertionPointToStart(moduleOp.getBody());
70+
b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
7071
SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix);
7172
return LLVM::GlobalOp::create(b, loc, globalType,
7273
/*isConstant=*/true, LLVM::Linkage::Internal,
@@ -396,10 +397,11 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
396397
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
397398
mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
398399
mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
399-
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
400-
// This ensures that global constants and declarations are placed within
401-
// the device code, not the host code
402-
auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
400+
401+
Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
402+
if (!moduleOp)
403+
return rewriter.notifyMatchFailure(gpuPrintfOp,
404+
"Couldn't find a parent module");
403405

404406
auto ocklBegin =
405407
getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
@@ -496,10 +498,10 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
496498
mlir::Type ptrType =
497499
LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
498500

499-
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
500-
// This ensures that global constants and declarations are placed within
501-
// the device code, not the host code
502-
auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
501+
Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
502+
if (!moduleOp)
503+
return rewriter.notifyMatchFailure(gpuPrintfOp,
504+
"Couldn't find a parent module");
503505

504506
auto printfType =
505507
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
@@ -541,10 +543,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
541543
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
542544
mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
543545

544-
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
545-
// This ensures that global constants and declarations are placed within
546-
// the device code, not the host code
547-
auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
546+
Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
547+
if (!moduleOp)
548+
return rewriter.notifyMatchFailure(gpuPrintfOp,
549+
"Couldn't find a parent module");
548550

549551
// Create a valid global location removing any metadata attached to the
550552
// location as debug info metadata inside of a function cannot be used outside

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,18 @@ namespace mlir {
1818
// Helper Functions
1919
//===----------------------------------------------------------------------===//
2020

21+
/// Note that these functions don't take a `SymbolTable` because GPU module
22+
/// lowerings can have name collisions as an intermediate state.
23+
2124
/// Find or create an external function declaration in the given module.
22-
LLVM::LLVMFuncOp getOrDefineFunction(gpu::GPUModuleOp moduleOp, Location loc,
25+
LLVM::LLVMFuncOp getOrDefineFunction(Operation *moduleOp, Location loc,
2326
OpBuilder &b, StringRef name,
2427
LLVM::LLVMFunctionType type);
2528

2629
/// Create a global that contains the given string. If a global with the same
2730
/// string already exists in the module, return that global.
2831
LLVM::GlobalOp getOrCreateStringConstant(OpBuilder &b, Location loc,
29-
gpu::GPUModuleOp moduleOp, Type llvmI8,
32+
Operation *moduleOp, Type llvmI8,
3033
StringRef namePrefix, StringRef str,
3134
uint64_t alignment = 0,
3235
unsigned addrSpace = 0);

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: mlir-opt %s -convert-gpu-to-rocdl='runtime=HIP' -split-input-file | FileCheck %s
22

3+
// CHECK-LABEL: gpu.module @test_module
34
gpu.module @test_module {
45
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
56
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
@@ -40,3 +41,38 @@ gpu.module @test_module {
4041
gpu.return
4142
}
4243
}
44+
45+
// -----
46+
47+
// The builtin.module we're targeting is wrapped in a fake gpu.module
48+
// because the convert-gpu-to-rocdl pass only runs on `gpu.module` ops,
49+
// even though the printf patterns could run in other contexts.
50+
51+
// CHECK-LABEL: gpu.module @fake_gpu_module_for_test
52+
// CHECK-LABEL: builtin.module @test_module
53+
gpu.module @fake_gpu_module_for_test {
54+
builtin.module @test_module {
55+
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
56+
// CHECK-DAG: llvm.func @__ockl_printf_append_args(i64, i32, i64, i64, i64, i64, i64, i64, i64, i32) -> i64
57+
// CHECK-DAG: llvm.func @__ockl_printf_append_string_n(i64, !llvm.ptr, i64, i32) -> i64
58+
// CHECK-DAG: llvm.func @__ockl_printf_begin(i64) -> i64
59+
60+
// CHECK-LABEL: llvm.func @test_printf
61+
// CHECK: (%[[ARG0:.*]]: i32)
62+
llvm.func @test_printf(%arg0: i32) {
63+
// CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
64+
// CHECK-NEXT: %[[DESC0:.*]] = llvm.call @__ockl_printf_begin(%0) : (i64) -> i64
65+
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr
66+
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<11 x i8>
67+
// CHECK-NEXT: %[[FORMATLEN:.*]] = llvm.mlir.constant(11 : i64) : i64
68+
// CHECK-NEXT: %[[ISLAST:.*]] = llvm.mlir.constant(1 : i32) : i32
69+
// CHECK-NEXT: %[[ISNTLAST:.*]] = llvm.mlir.constant(0 : i32) : i32
70+
// CHECK-NEXT: %[[DESC1:.*]] = llvm.call @__ockl_printf_append_string_n(%[[DESC0]], %[[FORMATSTART]], %[[FORMATLEN]], %[[ISNTLAST]]) : (i64, !llvm.ptr, i64, i32) -> i64
71+
// CHECK-NEXT: %[[NARGS1:.*]] = llvm.mlir.constant(1 : i32) : i32
72+
// CHECK-NEXT: %[[ARG0_64:.*]] = llvm.zext %[[ARG0]] : i32 to i64
73+
// CHECK-NEXT: %{{.*}} = llvm.call @__ockl_printf_append_args(%[[DESC1]], %[[NARGS1]], %[[ARG0_64]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[ISLAST]]) : (i64, i32, i64, i64, i64, i64, i64, i64, i64, i32) -> i64
74+
gpu.printf "Hello: %d\n", %arg0 : i32
75+
llvm.return
76+
}
77+
}
78+
}

0 commit comments

Comments
 (0)