Skip to content

Commit 3bb0b5f

Browse files
Merge OpenAI Triton commit 418c127 (#4517)
This PR change the Triton base from 717997b to 418c127 (Jun 11). Pass rate: 97.11%
2 parents f737aee + 9f22aac commit 3bb0b5f

File tree

18 files changed

+33
-36
lines changed

18 files changed

+33
-36
lines changed

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
e12cbd8339b89563059c2bb2a312579b652560d0
1+
8957e64a20fc7f4277565c6cfe3e555c119783ce

lib/Analysis/Utility.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
960960
BackwardSliceOptions opt;
961961
opt.omitBlockArguments = true;
962962
opt.filter = backwardFilter;
963-
getBackwardSlice(currentOp, &backwardSlice, opt);
963+
(void)getBackwardSlice(currentOp, &backwardSlice, opt);
964964
slice.insert(backwardSlice.begin(), backwardSlice.end());
965965

966966
// Compute and insert the forwardSlice starting from currentOp.

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ struct ElementwiseInlineAsmOpConversion
298298
/*asm_string=*/op.getAsmString(),
299299
/*constraints=*/op.getConstraints(),
300300
/*has_side_effects=*/!op.getPure(),
301-
/*is_align_stack=*/false,
301+
/*is_align_stack=*/false, LLVM::TailCallKind::None,
302302
/*asm_dialect=*/
303303
LLVM::AsmDialectAttr::get(rewriter.getContext(),
304304
LLVM::AsmDialect::AD_ATT),

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ static int computeOrigBitWidth(Value x) {
257257
mlir::BackwardSliceOptions opt;
258258
opt.omitBlockArguments = true;
259259
opt.filter = bwdFilter;
260-
getBackwardSlice(x, &slice, opt);
260+
(void)getBackwardSlice(x, &slice, opt);
261261

262262
// TODO: This heuristic may be a bit too coarse and may need improving
263263
// If the chain contains a fp4 to fp16/bf16 conversion, then the original

lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ CoarseSchedule::Cluster schedulePrologueAndEpilogue(scf::ForOp forOp,
271271
BackwardSliceOptions opt;
272272
opt.omitBlockArguments = true;
273273
opt.omitUsesFromAbove = false;
274-
getBackwardSlice((Operation *)op, &backwardSlice, opt);
274+
(void)getBackwardSlice((Operation *)op, &backwardSlice, opt);
275275

276276
for (auto op : backwardSlice) {
277277
if (auto ifOp = dyn_cast<scf::IfOp>(op)) {

lib/Dialect/TritonGPU/Transforms/Pipeliner/WGMMAPipeline.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait,
197197
return op->getBlock() == wait->getBlock();
198198
};
199199
SetVector<Operation *> slice;
200-
getBackwardSlice(v, &slice, options);
200+
(void)getBackwardSlice(v, &slice, options);
201201
}
202202

203203
for (ttng::WarpGroupDotOp dot : asyncDots) {

python/test/unit/language/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6980,7 +6980,7 @@ def test_tl_range_num_stages(device):
69806980
if capability[0] >= 8:
69816981
ptx = pgm.asm['ptx']
69826982
# check that the loop got pipelined with the right number of stages.
6983-
assert 'cp.async.wait_group 6' in ptx
6983+
assert 'cp.async.wait_group \t6' in ptx
69846984

69856985

69866986
def test_tl_range_fuse():

test/Triton/reproducer.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
1717
#-}
1818

1919
// CHECK: Pass Manager with
20-
// CHECK-NEXT: convert-triton-gpu-to-llvm
20+
// CHECK: convert-triton-gpu-to-llvm

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,30 +37,25 @@ cvtScalePkUpcastFromFp8(Location loc, ConversionPatternRewriter &rewriter,
3737
fp8x4Vec = b.insert_element(fp8x4VecTy, fp8x4Vec, v1, idx1);
3838
auto i32v = b.bitcast(fp8x4Vec, i32_ty);
3939

40-
auto resType = i32_ty;
41-
auto dstType = f32_ty;
40+
Type resElemType;
4241
if constexpr (std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF32Fp8Op> ||
4342
std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkF32Bf8Op>) {
44-
resType = i64_ty;
45-
dstType = f32_ty;
43+
resElemType = f32_ty;
4644
} else if constexpr (std::is_same_v<ConvertOp,
4745
ROCDL::CvtScaleF32PkF16Fp8Op> ||
4846
std::is_same_v<ConvertOp,
4947
ROCDL::CvtScaleF32PkF16Bf8Op>) {
50-
resType = i32_ty;
51-
dstType = f16_ty;
48+
resElemType = f16_ty;
5249
} else {
53-
resType = i32_ty;
54-
dstType = bf16_ty;
50+
resElemType = bf16_ty;
5551
}
52+
Type resType = vec_ty(resElemType, 2);
5653
Value scale = b.f32_val(1);
57-
Value select = b.false_val();
58-
auto result = rewriter.create<ConvertOp>(loc, resType, i32v, scale, select);
59-
auto retVecTy = vec_ty(dstType, 2);
60-
auto retVec = b.bitcast(result, retVecTy);
54+
auto result = rewriter.create<ConvertOp>(loc, resType, i32v, scale,
55+
/*srcLoHiSel=*/false);
6156
SmallVector<Value> ret(2);
62-
ret[0] = b.extract_element(dstType, retVec, idx0);
63-
ret[1] = b.extract_element(dstType, retVec, idx1);
57+
ret[0] = b.extract_element(resElemType, result, idx0);
58+
ret[1] = b.extract_element(resElemType, result, idx1);
6459
return ret;
6560
}
6661

@@ -73,13 +68,12 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
7368
Type v2I16Ty = vec_ty(i16_ty, 2);
7469
Value v2I16Vec = b.undef(v2I16Ty);
7570
Value scale = b.f32_val(1);
76-
Value select = b.false_val();
7771

7872
Value result;
7973
if constexpr (std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkFp8F32Op> ||
8074
std::is_same_v<ConvertOp, ROCDL::CvtScaleF32PkBf8F32Op>) {
8175
result = rewriter.create<ConvertOp>(loc, v2I16Ty, v2I16Vec, v0, v1, scale,
82-
select);
76+
/*dstLoHiSel=*/false);
8377
} else {
8478
Type v2F16Ty = vec_ty(v0.getType(), 2);
8579
Value srcVec = b.undef(v2F16Ty);
@@ -88,7 +82,7 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
8882
srcVec = b.insert_element(v2F16Ty, srcVec, v0, idx0);
8983
srcVec = b.insert_element(v2F16Ty, srcVec, v1, idx1);
9084
result = rewriter.create<ConvertOp>(loc, v2I16Ty, v2I16Vec, srcVec, scale,
91-
select);
85+
/*dstLoHiSel=*/false);
9286
}
9387
auto fp8x4VecTy = vec_ty(i8_ty, 4);
9488
auto fp8x4Vec = b.bitcast(result, fp8x4VecTy);
@@ -312,8 +306,8 @@ static SmallVector<Value> cvtPkF8ToFp32(Location loc,
312306
auto resType = i64_ty;
313307
auto dstType = f32_ty;
314308

315-
Value select = b.false_val();
316-
auto result = rewriter.create<ConvertOp>(loc, resType, i32v, select);
309+
auto result =
310+
rewriter.create<ConvertOp>(loc, resType, i32v, /*wordSel=*/false);
317311
auto f32x2VecTy = vec_ty(dstType, 2);
318312
auto retVec = b.bitcast(result, f32x2VecTy);
319313
SmallVector<Value> ret(2);
@@ -330,10 +324,10 @@ static SmallVector<Value> cvtPkFp32ToF8(Location loc,
330324
auto b = TritonLLVMOpBuilder(loc, rewriter);
331325
Type v2I16Ty = vec_ty(i16_ty, 2);
332326
Value old = b.undef(i32_ty);
333-
Value select = b.false_val();
334327

335328
Value result;
336-
result = rewriter.create<ConvertOp>(loc, v2I16Ty, v0, v1, old, select);
329+
result =
330+
rewriter.create<ConvertOp>(loc, v2I16Ty, v0, v1, old, /*wordSel=*/false);
337331
auto fp8x4VecTy = vec_ty(i8_ty, 4);
338332
auto fp8x4Vec = b.bitcast(result, fp8x4VecTy);
339333
SmallVector<Value> ret(2);

third_party/amd/lib/TritonAMDGPUToLLVM/GCNAsmFormat.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ mlir::Value GCNBuilder::launch(RewriterBase &rewriter, Location loc, Type resTy,
8282
getConstraints(), // constraints
8383
hasSideEffect, // has_side_effects
8484
isAlignStack, // is_align_stack
85+
LLVM::TailCallKind::None,
8586
LLVM::AsmDialectAttr::get(ctx,
8687
LLVM::AsmDialect::AD_ATT), // asm_dialect
8788
ArrayAttr::get(ctx, attrs) // operand_attrs

0 commit comments

Comments
 (0)