Commit 3797a0e

Merge OpenAI Triton commit 673ca35 (#4677)
This PR changes the Triton base from d2b6150 to 673ca35 (Jul 9). Pass rate: 96.19%
2 parents 7136842 + 8f9f753 commit 3797a0e

11 files changed: +169 −76 lines

lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp

Lines changed: 21 additions & 2 deletions
@@ -1,5 +1,6 @@
 #include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::triton;
@@ -19,8 +20,26 @@ class GenericFMAVectorMultiplier : public FMAVectorMultiplier {
     auto K = a.size();
     assert(b.size() == K);
     Value accum = c;
-    for (auto [aElem, bElem] : llvm::zip(a, b))
-      accum = builder.create<LLVM::FMulAddOp>(loc, aElem, bElem, accum);
+    Type tgtTy = accum.getType();
+    for (auto it = llvm::zip(a, b).begin(); it != llvm::zip(a, b).end(); ++it) {
+      const auto &aElem = std::get<0>(*it);
+      const auto &bElem = std::get<1>(*it);
+
+      assert(aElem.getType() == tgtTy);
+      assert(bElem.getType() == tgtTy);
+
+      // to avoid: 'llvm.intr.fmuladd' op operand #0 must be floating point LLVM
+      // type or LLVM dialect-compatible vector of floating point LLVM type, but
+      // got 'i32'
+      llvm::TypeSwitch<Type>(tgtTy)
+          .Case<FloatType>([&](auto) {
+            accum = builder.create<LLVM::FMulAddOp>(loc, aElem, bElem, accum);
+          })
+          .Case<IntegerType>([&](auto) {
+            accum = builder.create<LLVM::AddOp>(
+                loc, builder.create<LLVM::MulOp>(loc, aElem, bElem), accum);
+          });
+    }
     return accum;
   }
 };
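
Note on the change above: llvm.intr.fmuladd is only defined for floating-point operands, so integer accumulators (exercised by the new matmul_fmadot_integer test further down) now take a separate mul-then-add path. A minimal sketch of the llvm::TypeSwitch idiom that makes the split; the describeAccumLowering helper and its return strings are illustrative only, not part of this commit:

#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinTypes.h"

// Hypothetical helper: report which ops the FMA dot lowering would emit for a
// given accumulator element type.
static llvm::StringRef describeAccumLowering(mlir::Type elemTy) {
  return llvm::TypeSwitch<mlir::Type, llvm::StringRef>(elemTy)
      .Case<mlir::FloatType>([](mlir::FloatType) { return "llvm.intr.fmuladd"; })
      .Case<mlir::IntegerType>([](mlir::IntegerType) { return "llvm.mul + llvm.add"; })
      .Default([](mlir::Type) { return "unhandled"; });
}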

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 11 additions & 3 deletions
@@ -775,6 +775,14 @@ static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
   return builder.create<arith::ExtFOp>(loc, tensorPromotedType, operand);
 }
 
+static bool mmav2SupportsFp8Operands(int computeCapability) {
+  // promote operands for sm < 89 since fp8 mma is not natively supported
+  // although PTX instructions for mma v2 w/ fp8 operands exist for sm90 and
+  // sm100, they are emulated as fp16 upcasts + fp16 HMMA in SASS. sm120 has
+  // hardware support for fp8 operands w/ mmav2.
+  return computeCapability == 89 || computeCapability == 120;
+}
+
 // promote operands of dot op if the existing combination is not natively
 // supported.
 static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
@@ -787,10 +795,10 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
       bool isNativeFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
-      // promote operands for sm < 89 since fp8 mma is not natively supported
-      // promote operands for sm >= 90 when mma is not v3
+      // promote to f16 unless there's hardware support for fp8 operands
       if (!isNativeFP8 ||
-          (isNativeFP8 && (computeCapability == 89 || mmaLayout.isHopper())))
+          (isNativeFP8 && (mmav2SupportsFp8Operands(computeCapability) ||
+                           mmaLayout.isHopper())))
         return;
       promoteType = builder.getF16Type();
     } else {
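
A hedged standalone restatement of the promotion decision above, for readers skimming the diff; the keepNativeFp8 name and its boolean arguments are illustrative, not the pass's actual API:

// Hypothetical sketch: operands stay in fp8 only where the mma hardware can
// consume them directly; otherwise they are promoted to f16.
static bool keepNativeFp8(bool isNativeFP8, int computeCapability,
                          bool usesHopperMmaV3) {
  return isNativeFP8 && (usesHopperMmaV3 || computeCapability == 89 ||
                         computeCapability == 120);
}
// keepNativeFp8(true, 120, false) -> true  (sm120 mmav2 has hardware fp8 support)
// keepNativeFp8(true, 90,  false) -> false (sm90 mmav2 fp8 is emulated; promote to f16)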

test/Conversion/nvgpu_to_llvm.mlir

Lines changed: 0 additions & 14 deletions
@@ -2,25 +2,11 @@
 
 // CHECK-LABEL: @nvvm_syncs
 llvm.func @nvvm_syncs() {
-  // CHECK: wgmma.fence.sync.aligned;
-  nvgpu.wgmma_fence
-
-  // CHECK: wgmma.commit_group.sync.aligned;
-  nvgpu.wgmma_commit_group
-
-  // CHECK: barrier.cluster.wait.aligned;
-  nvgpu.cluster_wait
-
   // CHECK: fence.proxy.async.shared::cta;
   nvgpu.fence_async_shared {bCluster = false}
   // CHECK: fence.proxy.async.shared::cluster;
   nvgpu.fence_async_shared {bCluster = true}
 
-  // CHECK: barrier.cluster.arrive.aligned;
-  nvgpu.cluster_arrive {relaxed = false}
-  // CHECK: barrier.cluster.arrive.relaxed.aligned;
-  nvgpu.cluster_arrive {relaxed = true}
-
   llvm.return
 }
 
test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 59 additions & 0 deletions
@@ -1347,6 +1347,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#blocked}>
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#blocked}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.target" = "cuda:70", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_fmadot_integer
+  tt.func @matmul_fmadot_integer(%ptr:!tt.ptr<i32> {tt.divisibility = 16 : i32},
+                                 %a:!ttg.memdesc<32x16xi32, #shared, #smem>, %b:!ttg.memdesc<16x32xi32, #shared, #smem>) {
+    %cst = arith.constant dense<0> : tensor<32x32xi32, #blocked>
+    // CHECK-NOT: llvm.intr.fmuladd
+    // CHECK: llvm.mul
+    // CHECK: llvm.add
+    %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xi32, #shared, #smem> -> tensor<32x16xi32, #dot_operand_a>
+    %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xi32, #shared, #smem> -> tensor<16x32xi32, #dot_operand_b>
+
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = ieee : tensor<32x16xi32, #dot_operand_a> * tensor<16x32xi32, #dot_operand_b> -> tensor<32x32xi32, #blocked>
+    %30 = tt.splat %ptr : !tt.ptr<i32> -> tensor<32x1x!tt.ptr<i32>, #blocked>
+    %36 = tt.broadcast %30 : tensor<32x1x!tt.ptr<i32>, #blocked> -> tensor<32x32x!tt.ptr<i32>, #blocked>
+    tt.store %36, %28 : tensor<32x32x!tt.ptr<i32>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
 #mma = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
 #shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
@@ -2257,6 +2283,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 
 // -----
 
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
+module attributes {"ttg.num-warps" = 8 : i32, ttg.target = "cuda:120"} {
+  // CHECK-LABEL: mmav2_e5m2_e5m2_fp16
+  tt.func public @mmav2_e5m2_e5m2_fp16(%arg0: tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg2: tensor<32x32xf16, #mma>) {
+    // CHECK: mma.{{.*}}.col.f16.e5m2.e5m2.f16
+    %0 = tt.dot %arg0, %arg1, %arg2 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf16, #mma>
+    tt.return
+  }
+
+  // CHECK-LABEL: mmav2_e5m2_e4m3_fp16
+  tt.func public @mmav2_e5m2_e4m3_fp16(%arg0: tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg2: tensor<32x32xf16, #mma>) {
+    // CHECK: mma.{{.*}}.col.f16.e5m2.e4m3.f16
+    %0 = tt.dot %arg0, %arg1, %arg2 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf16, #mma>
+    tt.return
+  }
+
+  // CHECK-LABEL: mmav2_e4m3_e5m2_fp16
+  tt.func public @mmav2_e4m3_e5m2_fp16(%arg0: tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg2: tensor<32x32xf16, #mma>) {
+    // CHECK: mma.{{.*}}.col.f16.e4m3.e5m2.f16
+    %0 = tt.dot %arg0, %arg1, %arg2 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf16, #mma>
+    tt.return
+  }
+
+  // CHECK-LABEL: mmav2_e4m3_e4m3_fp16
+  tt.func public @mmav2_e4m3_e4m3_fp16(%arg0: tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>, %arg2: tensor<32x32xf16, #mma>) {
+    // CHECK: mma.{{.*}}.col.f16.e4m3.e4m3.f16
+    %0 = tt.dot %arg0, %arg1, %arg2 {maxNumImpreciseAcc = 1073741824 : i32} : tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf16, #mma>
+    tt.return
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [4, 4, 2], warpsPerCTA = [8, 1, 1], order = [2, 1, 0]}>
 #linear = #ttg.linear<{register = [[0, 0], [0, 0], [0, 0], [0, 0]], lane = [[0, 0], [0, 1], [0, 2], [1, 0], [2, 0]], warp = [[4, 0], [8, 0], [16, 0]], block = []}>
test/TritonGPU/accelerate-matmul.mlir

Lines changed: 21 additions & 0 deletions
@@ -562,6 +562,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
   }
 }
 
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: sm120_fp8_dot
+  tt.func public @sm120_fp8_dot(%arg0: tensor<128x256xf32, #blocked>, %arg1: tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked1>, %arg2: tensor<128x256x!tt.ptr<f8E4M3FN>, #blocked2>, %arg3: tensor<128x128xi1, #blocked1>, %arg4: tensor<128x256xi1, #blocked2>) -> tensor<128x256xf32, #blocked> {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf8E4M3FN, #blocked2>
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf8E4M3FN, #blocked1>
+    %0 = tt.load %arg1, %arg3, %cst_0 : tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked1>
+    %1 = tt.load %arg2, %arg4, %cst : tensor<128x256x!tt.ptr<f8E4M3FN>, #blocked2>
+    %2 = ttg.convert_layout %0 : tensor<128x128xf8E4M3FN, #blocked1> -> tensor<128x128xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+    %3 = ttg.convert_layout %1 : tensor<128x256xf8E4M3FN, #blocked2> -> tensor<128x256xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK: {{.*}} = tt.dot {{.*}} tensor<128x128xf8E4M3FN
+    %4 = tt.dot %2, %3, %arg0, inputPrecision = tf32 : tensor<128x128xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x256xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
+    tt.return %4 : tensor<128x256xf32, #blocked>
+  }
+}
+
+
 // -----
 
 #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>

third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp

Lines changed: 9 additions & 5 deletions
@@ -50,15 +50,19 @@ createTmpLayout(triton::gpu::DistributedEncodingTrait layout,
                                             src.getOrder(), src.getCTALayout());
   if (auto src = dyn_cast<triton::gpu::DotOperandEncodingAttr>(layout)) {
     auto parent = cast<triton::gpu::DistributedEncodingTrait>(src.getParent());
-    return triton::gpu::DotOperandEncodingAttr::get(
-        ctx, src.getOpIdx(), createTmpLayout(parent, warpsPerCTA),
-        src.getKWidth());
+    parent = createTmpLayout(parent, warpsPerCTA);
+    if (!parent)
+      return {};
+    return triton::gpu::DotOperandEncodingAttr::get(ctx, src.getOpIdx(), parent,
+                                                    src.getKWidth());
   }
   if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {
     auto warps = to_vector(warpsPerCTA);
     warps.insert(warps.begin() + src.getDim(), 1);
-    return triton::gpu::SliceEncodingAttr::get(
-        ctx, src.getDim(), createTmpLayout(src.getParent(), warps));
+    auto parent = createTmpLayout(src.getParent(), warps);
+    if (!parent)
+      return {};
+    return triton::gpu::SliceEncodingAttr::get(ctx, src.getDim(), parent);
   }
   // TODO: support linear layout if needed.
   if (isa<triton::gpu::LinearEncodingAttr>(layout))
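
The change above makes createTmpLayout propagate failure by returning a null attribute instead of wrapping a null parent layout. A minimal sketch of why "return {};" works as the failure signal, assuming only that MLIR attributes are nullable value handles; the useLayout function is illustrative, not part of this commit:

#include "mlir/IR/Attributes.h"

// Hypothetical caller-side check: mlir::Attribute converts to bool by testing
// for null, so a default-constructed ("{}") return value reads as failure.
void useLayout(mlir::Attribute maybeLayout) {
  if (!maybeLayout)
    return; // createTmpLayout gave up (e.g. a nested parent could not be built)
  // ... safe to use maybeLayout here ...
}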

third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td

Lines changed: 0 additions & 18 deletions
@@ -63,14 +63,6 @@ def NVGPU_MemSyncScopeAttr : I32EnumAttr<
 class NVGPU_Op<string mnemonic, list<Trait> traits = []> :
   LLVM_OpBase<NVGPU_Dialect, mnemonic, traits>;
 
-def NVGPU_WGMMAFenceOp : NVGPU_Op<"wgmma_fence", []> {
-  let assemblyFormat = "attr-dict";
-}
-
-def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> {
-  let assemblyFormat = "attr-dict";
-}
-
 def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
                                                            AllTypesMatch<["input", "output"]>]> {
   let arguments = (ins LLVM_AnyStruct:$input, I32Attr:$pendings);
@@ -118,16 +110,6 @@ def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> {
   let assemblyFormat = "attr-dict";
 }
 
-def NVGPU_ClusterArriveOp : NVGPU_Op<"cluster_arrive", []> {
-  let arguments = (ins I1Attr:$relaxed);
-
-  let assemblyFormat = "attr-dict";
-}
-
-def NVGPU_ClusterWaitOp : NVGPU_Op<"cluster_wait", []> {
-  let assemblyFormat = "attr-dict";
-}
-
 def NVGPU_StoreMatrixOp : NVGPU_Op<"stmatrix", [MemoryEffects<[MemWrite]>]> {
   let arguments = (
     ins LLVM_PointerShared:$addr,

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

Lines changed: 3 additions & 29 deletions
@@ -23,10 +23,6 @@ namespace triton {
 
 namespace {
 
-const std::string kWgmmaFenceOp = "wgmma.fence.sync.aligned;";
-const std::string kWgmmaCommitGroupOp = "wgmma.commit_group.sync.aligned;";
-const std::string kClusterWaitOp = "barrier.cluster.wait.aligned;";
-const std::string kFenceMbarrierInitOp = "fence.mbarrier_init.release.cluster;";
 const std::string kClusterCtaIdOp = "{\n"
                                     ".reg .u32 a<5>; \n"
                                     "mov.u32 a0, %cluster_ctaid.x;\n" // x
@@ -255,19 +251,6 @@ class WarpIdOpPattern : public OpRewritePattern<ttn::WarpIdOp> {
   }
 };
 
-class ClusterArriveOpPattern : public OpRewritePattern<ttn::ClusterArriveOp> {
-public:
-  using OpRewritePattern<ttn::ClusterArriveOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ttn::ClusterArriveOp op,
-                                PatternRewriter &rewriter) const override {
-    std::string ptxAsm = op.getRelaxed()
-                             ? "barrier.cluster.arrive.relaxed.aligned;"
-                             : "barrier.cluster.arrive.aligned;";
-    return rewriteAsPtxAsm(op, rewriter, std::move(ptxAsm));
-  }
-};
-
 // Base class for Matrix Operation Patterns
 template <typename MatrixOpType, typename ConcreteMatrixOpPattern>
 class MatrixOpPattern : public OpRewritePattern<MatrixOpType> {
@@ -788,21 +771,12 @@ class ConvertNVGPUToLLVM
     ModuleOp mod = getOperation();
     RewritePatternSet patterns(context);
 
-#define POPULATE_NVGPU_OP(SRC_OP, ASM)                                        \
-  patterns.add<NVGPUOpGenericPattern<SRC_OP>>(context, ASM, Constraints(),    \
-                                              Constraints());
-    POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, kWgmmaFenceOp)
-    POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, kWgmmaCommitGroupOp)
-    POPULATE_NVGPU_OP(ttn::ClusterWaitOp, kClusterWaitOp)
-#undef POPULATE_NVGPU_OP
     patterns.add<NVGPUOpGenericPattern<ttn::ClusterCTAIdOp>>(
         context, kClusterCtaIdOp, Constraints({"=r"}), Constraints());
 
-    patterns
-        .add<FenceAsyncSharedOpPattern, LoadMatrixOpPattern,
-             StoreMatrixOpPattern, ClusterArriveOpPattern, WGMMAOpPattern,
-             LoadAcquireOpPattern, WGMMAWaitGroupOpPattern, WarpIdOpPattern>(
-            context);
+    patterns.add<FenceAsyncSharedOpPattern, LoadMatrixOpPattern,
+                 StoreMatrixOpPattern, WGMMAOpPattern, LoadAcquireOpPattern,
+                 WGMMAWaitGroupOpPattern, WarpIdOpPattern>(context);
 
     if (applyPatternsGreedily(mod, std::move(patterns)).failed())
       signalPassFailure();

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ClusterOpsToLLVM.cpp

Lines changed: 10 additions & 3 deletions
@@ -24,6 +24,7 @@
 #include "Dialect/NVGPU/IR/Dialect.h"
 #include "PatternTritonGPUOpToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 
 using namespace mlir;
@@ -38,8 +39,13 @@ struct ClusterArriveOpConversion
   LogicalResult
   matchAndRewrite(triton::nvidia_gpu::ClusterArriveOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<triton::nvgpu::ClusterArriveOp>(
-        op, op.getRelaxed());
+    auto ctx = rewriter.getContext();
+    auto unitAttr = UnitAttr::get(ctx);
+    if (op.getRelaxed()) {
+      rewriter.replaceOpWithNewOp<NVVM::ClusterArriveRelaxedOp>(op, unitAttr);
+    } else {
+      rewriter.replaceOpWithNewOp<NVVM::ClusterArriveOp>(op, unitAttr);
+    }
     return success();
   }
 };
@@ -52,7 +58,8 @@ struct ClusterWaitOpConversion
   LogicalResult
   matchAndRewrite(triton::nvidia_gpu::ClusterWaitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<triton::nvgpu::ClusterWaitOp>(op);
+    auto ctx = rewriter.getContext();
+    rewriter.replaceOpWithNewOp<NVVM::ClusterWaitOp>(op, UnitAttr::get(ctx));
     return success();
   }
 };
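
For reference, a hedged summary of the new lowering path for cluster barriers, reconstructed from the nvgpu ops and PTX checks deleted above; the PTX shown is what the old hand-written inline-assembly path emitted, and the NVVM dialect ops are expected to produce the equivalent instructions:

// triton::nvidia_gpu::ClusterArriveOp {relaxed = false}
//   -> NVVM::ClusterArriveOp         (was "barrier.cluster.arrive.aligned;")
// triton::nvidia_gpu::ClusterArriveOp {relaxed = true}
//   -> NVVM::ClusterArriveRelaxedOp  (was "barrier.cluster.arrive.relaxed.aligned;")
// triton::nvidia_gpu::ClusterWaitOp
//   -> NVVM::ClusterWaitOp           (was "barrier.cluster.wait.aligned;")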
