Commit 847d9fc

Merge OpenAI Triton commit 04159ed (#3574)
This PR changes the Triton base from 3523ab4 to 04159ed (Feb 26). Pass rate: 90.1% -> 92.35%
2 parents 178ebd4 + 0beb855 commit 847d9fc

51 files changed: +975 −957 lines changed


bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
@@ -90,6 +90,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   // TritonAMDGPUTransforms passes
   mlir::registerTritonAMDGPUAccelerateMatmul();
   mlir::registerTritonAMDGPUOptimizeEpilogue();
+  mlir::registerTritonAMDGPUHoistLayoutConversions();
   mlir::registerTritonAMDGPUReorderInstructions();
   mlir::registerTritonAMDGPUBlockPingpong();
   mlir::registerTritonAMDGPUStreamPipeline();
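For context, this helper is what a Triton opt-style tool calls to make every dialect and pass (now including the AMD hoist-layout-conversions pass) visible to its pipeline. A minimal sketch of such a driver, assuming the standard MlirOptMain entry point; the include path and tool name below are illustrative, not part of this commit:

#include "RegisterTritonDialects.h"           // assumed include path for the header above
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  // Registers all Triton dialects and passes, now including
  // mlir::registerTritonAMDGPUHoistLayoutConversions() added by this commit.
  registerTritonDialects(registry);
  return mlir::failed(
             mlir::MlirOptMain(argc, argv, "triton opt driver\n", registry))
             ? 1
             : 0;
}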

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 0 additions & 7 deletions
@@ -9,13 +9,6 @@ using namespace mlir;
 using namespace mlir::triton;
 
 using ::mlir::triton::gpu::BlockedEncodingAttr;
-
-namespace SharedToDotOperandFMA {
-Value convertLayout(int opIdx, Value val, Value llVal,
-                    BlockedEncodingAttr dLayout, Value thread, Location loc,
-                    const LLVMTypeConverter *typeConverter,
-                    ConversionPatternRewriter &rewriter);
-}
 LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
                             const LLVMTypeConverter *typeConverter,
                             ConversionPatternRewriter &rewriter);

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 2 deletions
@@ -505,7 +505,7 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
 
 def TT_JoinOp : TT_Op<"join", [
     NoMemoryEffect, SameTypeOperands,
-    DeclareOpInterfaceMethods<InferTypeOpInterface>,
+    InferTypeOpWithLayoutEquivalence,
 ]> {
   let summary = "join two tensors along a new, minor dimension";
   let description = [{
@@ -523,7 +523,7 @@ def TT_JoinOp : TT_Op<"join", [
 
 def TT_SplitOp : TT_Op<"split", [
     NoMemoryEffect,
-    DeclareOpInterfaceMethods<InferTypeOpInterface>,
+    InferTypeOpWithLayoutEquivalence,
     TypesMatchWith<"outLHS and outRHS types match",
                    "outLHS", "outRHS", "$_self">,
 ]> {

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 2 deletions
@@ -264,8 +264,8 @@ LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
 
 // The primary goal of this function is to efficiently load 2D tiles of a
 // tensor from shared memory using the `ds_read_tr` instruction for AMD GPUs.
-LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape,
-                                       int32_t elemBitWidth);
+LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
+                                     int32_t elemBitWidth);
 
 // Create LinearLayout for mxfp4 and mxfp8 operand in scaled mfma.
 // For mxfp4, we use dot layout directly. Mxfp8 is not covered by dot
@@ -275,6 +275,9 @@ chooseScaledMfmaOperandLayout(AMDMfmaEncodingAttr mfmaEnc, int kWidth,
                               int dotOperandIdx, ScaleDotElemType elemType,
                               llvm::ArrayRef<int64_t> dotOperandShape);
 
+LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
+                                           int numWarps);
+
 // Create LinearLayout for scale in scaled mfma.
 LinearLayout chooseScaledMfmaScaleLayout(
     MLIRContext *ctx, int dotOperandIdx,
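For orientation, a hypothetical call site for the renamed helper, built only from the signature declared above; the wrapper name, the using-directives, and the 16-bit element width are assumptions for illustration, not code from this commit:

#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;  // assumed namespace of the declaration above

// Picks the register layout matching what ds_read_tr produces for a 2D tile
// loaded from shared memory, so no extra in-register shuffle is needed.
LinearLayout pickTransposedLoadLayout(Attribute dotOperandEnc,
                                      ArrayRef<int64_t> tileShape) {
  // elemBitWidth = 16 is an illustrative choice; the helper also accepts other
  // widths, which is presumably why the "16" was dropped from its name.
  return chooseDsReadB64TrLayout(dotOperandEnc, tileShape,
                                 /*elemBitWidth=*/16);
}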

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ static const char *kDisallowAccMultiBufferAttrName =
     "tt.disallow_acc_multi_buffer";
 static const char *kLoopStageAttrName = "loop.stage";
 static const char *kLoopClusterAttrName = "loop.cluster";
+static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
 static const char *kLatencyAttrName = "tt.latency";
 
 bool loopHasDistGreaterThanOne(scf::ForOp forOp);
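The new attribute name follows the same pattern as the other pipelining markers above. A small sketch of how such a constant is typically used to tag a scheduled op and read the value back, assuming the constants live in the mlir::triton namespace like the rest of these utilities; the helper names are hypothetical:

#include <cstdint>
#include <optional>

#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"

// Hypothetical helper: record the scheduler's maximum stage on an op.
static void setScheduledMaxStage(mlir::Operation *op, int64_t maxStage) {
  mlir::OpBuilder b(op->getContext());
  op->setAttr(mlir::triton::kScheduledMaxStageAttrName,
              b.getI64IntegerAttr(maxStage));
}

// Hypothetical helper: read the stage back, if the op was tagged.
static std::optional<int64_t> getScheduledMaxStage(mlir::Operation *op) {
  if (auto attr = op->getAttrOfType<mlir::IntegerAttr>(
          mlir::triton::kScheduledMaxStageAttrName))
    return attr.getInt();
  return std::nullopt;
}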

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 2 deletions
@@ -332,7 +332,7 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryE
 
   let description = [{
     $d += matrix_multiply($a, $b).
-    If not barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
+    If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
     If there is a barrier the result will be safe to read after a barrier wait.
     If $two_ctas is set the op will execute a matmul across two contiguous CTAs, it will read the data distributed across the two CTAs.
     and syncronize both CTAs if the op is synchronous.
@@ -355,7 +355,7 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMe
 
   let description = [{
     $d += matrix_multiply(scale($lhs, $lhs_scale), scale(rlhs, $rhs_scale))
-    If not barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
+    If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
     If there is a barrier the result will be safe to read after a barrier wait.
   }];
 

lib/Conversion/TritonGPUToLLVM/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 add_triton_library(TritonGPUToLLVM
-  ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
   DotOpToLLVM/FMA.cpp
   DotOpToLLVM/FMADotUtility.cpp
   AllocateSharedMemory.cpp
