Skip to content

Commit 8e79a35

Browse files
authored
[NVIDIA] Replace some NVGPU ops with equivalent NVVM ops (#7420)
This change updates the lowering of WGMMAFenceOp, WGMMACommitGroupOp and ClusterWaitOp to generate NVVM dialect operations instead of inline assembly strings. The NVVM ops will be lowered to LLVM intrinsics in subsequent passes, providing better optimization opportunities. Additionally, the unused constant kFenceMbarrierInitOp is cleaned up.
1 parent 299b3bb commit 8e79a35

File tree

5 files changed

+16
-66
lines changed

5 files changed

+16
-66
lines changed

test/Conversion/nvgpu_to_llvm.mlir

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,11 @@
22

33
// CHECK-LABEL: @nvvm_syncs
44
llvm.func @nvvm_syncs() {
5-
// CHECK: wgmma.fence.sync.aligned;
6-
nvgpu.wgmma_fence
7-
8-
// CHECK: wgmma.commit_group.sync.aligned;
9-
nvgpu.wgmma_commit_group
10-
11-
// CHECK: barrier.cluster.wait.aligned;
12-
nvgpu.cluster_wait
13-
145
// CHECK: fence.proxy.async.shared::cta;
156
nvgpu.fence_async_shared {bCluster = false}
167
// CHECK: fence.proxy.async.shared::cluster;
178
nvgpu.fence_async_shared {bCluster = true}
189

19-
// CHECK: barrier.cluster.arrive.aligned;
20-
nvgpu.cluster_arrive {relaxed = false}
21-
// CHECK: barrier.cluster.arrive.relaxed.aligned;
22-
nvgpu.cluster_arrive {relaxed = true}
23-
2410
llvm.return
2511
}
2612

third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,6 @@ def NVGPU_MemSyncScopeAttr : I32EnumAttr<
6363
class NVGPU_Op<string mnemonic, list<Trait> traits = []> :
6464
LLVM_OpBase<NVGPU_Dialect, mnemonic, traits>;
6565

66-
def NVGPU_WGMMAFenceOp : NVGPU_Op<"wgmma_fence", []> {
67-
let assemblyFormat = "attr-dict";
68-
}
69-
70-
def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> {
71-
let assemblyFormat = "attr-dict";
72-
}
73-
7466
def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
7567
AllTypesMatch<["input", "output"]>]> {
7668
let arguments = (ins LLVM_AnyStruct:$input, I32Attr:$pendings);
@@ -118,16 +110,6 @@ def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> {
118110
let assemblyFormat = "attr-dict";
119111
}
120112

121-
def NVGPU_ClusterArriveOp : NVGPU_Op<"cluster_arrive", []> {
122-
let arguments = (ins I1Attr:$relaxed);
123-
124-
let assemblyFormat = "attr-dict";
125-
}
126-
127-
def NVGPU_ClusterWaitOp : NVGPU_Op<"cluster_wait", []> {
128-
let assemblyFormat = "attr-dict";
129-
}
130-
131113
def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> {
132114
let arguments = (
133115
ins LLVM_PointerShared:$addr,

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,6 @@ namespace triton {
2323

2424
namespace {
2525

26-
const std::string kWgmmaFenceOp = "wgmma.fence.sync.aligned;";
27-
const std::string kWgmmaCommitGroupOp = "wgmma.commit_group.sync.aligned;";
28-
const std::string kClusterWaitOp = "barrier.cluster.wait.aligned;";
29-
const std::string kFenceMbarrierInitOp = "fence.mbarrier_init.release.cluster;";
3026
const std::string kClusterCtaIdOp = "{\n"
3127
".reg .u32 a<5>; \n"
3228
"mov.u32 a0, %cluster_ctaid.x;\n" // x
@@ -255,19 +251,6 @@ class WarpIdOpPattern : public OpRewritePattern<ttn::WarpIdOp> {
255251
}
256252
};
257253

258-
class ClusterArriveOpPattern : public OpRewritePattern<ttn::ClusterArriveOp> {
259-
public:
260-
using OpRewritePattern<ttn::ClusterArriveOp>::OpRewritePattern;
261-
262-
LogicalResult matchAndRewrite(ttn::ClusterArriveOp op,
263-
PatternRewriter &rewriter) const override {
264-
std::string ptxAsm = op.getRelaxed()
265-
? "barrier.cluster.arrive.relaxed.aligned;"
266-
: "barrier.cluster.arrive.aligned;";
267-
return rewriteAsPtxAsm(op, rewriter, std::move(ptxAsm));
268-
}
269-
};
270-
271254
// Base class for Matrix Operation Patterns
272255
template <typename MatrixOpType, typename ConcreteMatrixOpPattern>
273256
class MatrixOpPattern : public OpRewritePattern<MatrixOpType> {
@@ -788,21 +771,12 @@ class ConvertNVGPUToLLVM
788771
ModuleOp mod = getOperation();
789772
RewritePatternSet patterns(context);
790773

791-
#define POPULATE_NVGPU_OP(SRC_OP, ASM) \
792-
patterns.add<NVGPUOpGenericPattern<SRC_OP>>(context, ASM, Constraints(), \
793-
Constraints());
794-
POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, kWgmmaFenceOp)
795-
POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, kWgmmaCommitGroupOp)
796-
POPULATE_NVGPU_OP(ttn::ClusterWaitOp, kClusterWaitOp)
797-
#undef POPULATE_NVGPU_OP
798774
patterns.add<NVGPUOpGenericPattern<ttn::ClusterCTAIdOp>>(
799775
context, kClusterCtaIdOp, Constraints({"=r"}), Constraints());
800776

801-
patterns
802-
.add<FenceAsyncSharedOpPattern, LoadMatrixOpPattern,
803-
StoreMatrixOpPattern, ClusterArriveOpPattern, WGMMAOpPattern,
804-
LoadAcquireOpPattern, WGMMAWaitGroupOpPattern, WarpIdOpPattern>(
805-
context);
777+
patterns.add<FenceAsyncSharedOpPattern, LoadMatrixOpPattern,
778+
StoreMatrixOpPattern, WGMMAOpPattern, LoadAcquireOpPattern,
779+
WGMMAWaitGroupOpPattern, WarpIdOpPattern>(context);
806780

807781
if (applyPatternsGreedily(mod, std::move(patterns)).failed())
808782
signalPassFailure();

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ClusterOpsToLLVM.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "Dialect/NVGPU/IR/Dialect.h"
2525
#include "PatternTritonGPUOpToLLVM.h"
2626
#include "mlir/Conversion/LLVMCommon/Pattern.h"
27+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
2728
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
2829

2930
using namespace mlir;
@@ -38,8 +39,13 @@ struct ClusterArriveOpConversion
3839
LogicalResult
3940
matchAndRewrite(triton::nvidia_gpu::ClusterArriveOp op, OpAdaptor adaptor,
4041
ConversionPatternRewriter &rewriter) const override {
41-
rewriter.replaceOpWithNewOp<triton::nvgpu::ClusterArriveOp>(
42-
op, op.getRelaxed());
42+
auto ctx = rewriter.getContext();
43+
auto unitAttr = UnitAttr::get(ctx);
44+
if (op.getRelaxed()) {
45+
rewriter.replaceOpWithNewOp<NVVM::ClusterArriveRelaxedOp>(op, unitAttr);
46+
} else {
47+
rewriter.replaceOpWithNewOp<NVVM::ClusterArriveOp>(op, unitAttr);
48+
}
4349
return success();
4450
}
4551
};
@@ -52,7 +58,8 @@ struct ClusterWaitOpConversion
5258
LogicalResult
5359
matchAndRewrite(triton::nvidia_gpu::ClusterWaitOp op, OpAdaptor adaptor,
5460
ConversionPatternRewriter &rewriter) const override {
55-
rewriter.replaceOpWithNewOp<triton::nvgpu::ClusterWaitOp>(op);
61+
auto ctx = rewriter.getContext();
62+
rewriter.replaceOpWithNewOp<NVVM::ClusterWaitOp>(op, UnitAttr::get(ctx));
5663
return success();
5764
}
5865
};

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include "MMAHelpers.h"
2525
#include "Utility.h"
26+
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
2627
#include "mlir/Support/LLVM.h"
2728

2829
using namespace mlir;
@@ -408,7 +409,7 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
408409
: triton::nvgpu::WGMMALayout::col;
409410

410411
auto func = op->getParentOfType<LLVM::LLVMFuncOp>();
411-
Operation *startSequence = rewriter.create<triton::nvgpu::WGMMAFenceOp>(loc);
412+
Operation *startSequence = rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
412413
SmallVector<Value> mmaResults;
413414
for (int m = 0; m < numRepM; ++m) {
414415
for (int n = 0; n < numRepN; ++n) {
@@ -479,7 +480,7 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
479480
}
480481
}
481482
}
482-
rewriter.create<triton::nvgpu::WGMMACommitGroupOp>(loc);
483+
rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
483484

484485
if (sync)
485486
mmaResults = emitWait(rewriter, loc, mmaResults, 0);

0 commit comments

Comments (0)