Skip to content

Commit 579520a

Browse files
committed
[mlir][rocdl] Add rocdl inlining interface
All rocdl ops should be safe to inline.
1 parent fd7aae3 commit 579520a

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/IR/DialectImplementation.h"
2424
#include "mlir/IR/MLIRContext.h"
2525
#include "mlir/IR/Operation.h"
26+
#include "mlir/Transforms/InliningUtils.h"
2627
#include "llvm/ADT/TypeSwitch.h"
2728

2829
using namespace mlir;
@@ -180,6 +181,15 @@ void RawBufferAtomicUMinOp::print(mlir::OpAsmPrinter &p) {
180181
// ROCDLDialect initialization, type parsing, and registration.
181182
//===----------------------------------------------------------------------===//
182183

184+
namespace {
185+
struct ROCDLInlinerInterface final : DialectInlinerInterface {
186+
using DialectInlinerInterface::DialectInlinerInterface;
187+
bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
188+
return true;
189+
}
190+
};
191+
} // namespace
192+
183193
// TODO: This should be the llvm.rocdl dialect once this is supported.
184194
void ROCDLDialect::initialize() {
185195
addOperations<
@@ -194,6 +204,7 @@ void ROCDLDialect::initialize() {
194204

195205
// Support unknown operations because not all ROCDL operations are registered.
196206
allowUnknownOperations();
207+
addInterfaces<ROCDLInlinerInterface>();
197208
declarePromisedInterface<gpu::TargetAttrInterface, ROCDLTargetAttr>();
198209
}
199210

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-opt %s --inline | FileCheck %s
2+
3+
llvm.func @threadidx() -> i32 {
4+
%tid = rocdl.workitem.id.x : i32
5+
llvm.return %tid : i32
6+
}
7+
8+
// CHECK-LABEL: func @caller
9+
llvm.func @caller() -> i32 {
10+
// CHECK-NOT: llvm.call @threadidx
11+
// CHECK: rocdl.workitem.id.x
12+
%z = llvm.call @threadidx() : () -> (i32)
13+
llvm.return %z : i32
14+
}

0 commit comments

Comments
 (0)