Skip to content

Commit 2f5a031

Browse files
authored
[NVIDIA] Replace inline assembly for the lowering of ttn::ClusterCTAIdOp (#7512)
This PR refactors the ClusterCTAIdOp conversion from using inline PTX assembly to a series of operations (including NVVM ops that can generate intrinsic calls), preserving more semantic information at the LLVM level. While the new implementation expands the computation into separate multiply and add operations, the backend will typically optimize them into `mad`, so there is no performance regression.
1 parent 3772dbd commit 2f5a031

File tree

2 files changed

+29
-21
lines changed

2 files changed

+29
-21
lines changed

test/Conversion/nvgpu_to_llvm.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
// Verify that nvgpu.cluster_id lowers to the NVVM cluster sreg reads
// (ctaid.x/y/z and nctaid.x/y) rather than inline PTX assembly.
// CHECK-LABEL: @cluster_id
llvm.func @cluster_id() -> i32 {
  // CHECK: nvvm.read.ptx.sreg.cluster.ctaid.x
  // CHECK: nvvm.read.ptx.sreg.cluster.ctaid.y
  // CHECK: nvvm.read.ptx.sreg.cluster.ctaid.z
  // CHECK: nvvm.read.ptx.sreg.cluster.nctaid.x
  // CHECK: nvvm.read.ptx.sreg.cluster.nctaid.y
  %id = nvgpu.cluster_id
  llvm.return %id : i32
}

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,6 @@ namespace triton {
2323

2424
namespace {
2525

26-
const std::string kClusterCtaIdOp = "{\n"
27-
".reg .u32 a<5>; \n"
28-
"mov.u32 a0, %cluster_ctaid.x;\n" // x
29-
"mov.u32 a1, %cluster_ctaid.y;\n" // y
30-
"mov.u32 a2, %cluster_ctaid.z;\n" // z
31-
"mov.u32 a3, %cluster_nctaid.x;\n" // nx
32-
"mov.u32 a4, %cluster_nctaid.y;\n" // ny
33-
"mad.lo.u32 a1, a2, a4, a1; \n"
34-
"mad.lo.u32 $0, a1, a3, a0; \n"
35-
"}";
36-
3726
bool isNumber(const std::string &s) {
3827
return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) {
3928
return !std::isdigit(c);
@@ -238,6 +227,26 @@ class WarpIdOpPattern : public OpRewritePattern<ttn::WarpIdOp> {
238227
}
239228
};
240229

230+
// Lowers ttn::ClusterCTAIdOp to NVVM special-register reads plus plain
// LLVM integer arithmetic, computing the row-major linearized CTA id
// within a cluster:
//   id = ctaid.x + (ctaid.y + ctaid.z * nctaid.y) * nctaid.x
// The separate mul/add pairs are typically fused into `mad` by the
// NVPTX backend, so this matches the old inline-asm lowering.
class ClusterCTAIdOpPattern : public OpRewritePattern<ttn::ClusterCTAIdOp> {
  using OpRewritePattern<ttn::ClusterCTAIdOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ttn::ClusterCTAIdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // CTA coordinates inside the cluster.
    Value ctaX = rewriter.create<NVVM::BlockInClusterIdXOp>(loc, i32_ty);
    Value ctaY = rewriter.create<NVVM::BlockInClusterIdYOp>(loc, i32_ty);
    Value ctaZ = rewriter.create<NVVM::BlockInClusterIdZOp>(loc, i32_ty);
    // Cluster extents along x and y (z extent is not needed for the
    // linearization).
    Value dimX = rewriter.create<NVVM::ClusterDimBlocksXOp>(loc, i32_ty);
    Value dimY = rewriter.create<NVVM::ClusterDimBlocksYOp>(loc, i32_ty);
    // ((z * ny) + y) * nx + x
    Value zTimesNy = rewriter.create<LLVM::MulOp>(loc, ctaZ, dimY);
    Value yPlusZNy = rewriter.create<LLVM::AddOp>(loc, ctaY, zTimesNy);
    Value rowBase = rewriter.create<LLVM::MulOp>(loc, yPlusZNy, dimX);
    Value linearId = rewriter.create<LLVM::AddOp>(loc, ctaX, rowBase);
    rewriter.replaceOp(op, linearId);
    return success();
  }
};
249+
241250
// Base class for Matrix Operation Patterns
242251
template <typename MatrixOpType, typename ConcreteMatrixOpPattern>
243252
class MatrixOpPattern : public OpRewritePattern<MatrixOpType> {
@@ -736,11 +745,10 @@ class ConvertNVGPUToLLVM
736745
ModuleOp mod = getOperation();
737746
RewritePatternSet patterns(context);
738747

739-
patterns.add<NVGPUOpGenericPattern<ttn::ClusterCTAIdOp>>(
740-
context, kClusterCtaIdOp, Constraints({"=r"}), Constraints());
741-
742-
patterns.add<LoadMatrixOpPattern, WGMMAOpPattern, LoadAcquireOpPattern,
743-
WGMMAWaitGroupOpPattern, WarpIdOpPattern>(context);
748+
patterns
749+
.add<ClusterCTAIdOpPattern, LoadMatrixOpPattern, WGMMAOpPattern,
750+
LoadAcquireOpPattern, WGMMAWaitGroupOpPattern, WarpIdOpPattern>(
751+
context);
744752

745753
if (applyPatternsGreedily(mod, std::move(patterns)).failed())
746754
signalPassFailure();

0 commit comments

Comments
 (0)