Skip to content

Commit cb6997a

Browse files
Merge OpenAI Triton commit f1f9ed9 (#4080)
This PR changes the Triton base from 62fbca4 to f1f9ed9 (Apr 29). Pass rate: 92.08%. Please do not squash and merge this PR.
2 parents add2d40 + d5e41d1 commit cb6997a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1109
-823
lines changed

bench/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
triton_bench.egg-info/

include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@ namespace triton::nvidia_gpu {
1717
// MMA Pipeline Analysis
1818
//===----------------------------------------------------------------------===//
1919

20-
// Returns the TMEMAllocOp and TMEMLoadOp that are used to allocate and load the
21-
// accumulator for the given MMA operation. The TMEMAllocOp and TMEMLoadOp must
22-
// be in the same region as the MMA operation.
23-
std::optional<std::pair<TMEMAllocOp, TMEMLoadOp>>
24-
getTMemAllocAndLoad(MMAv5OpInterface mmaOp);
2520
// Given an MMAv5 operation in a loop, determine if its accumulator can be
2621
// multibuffered.
2722
bool isAccMultibufferingPossible(MMAv5OpInterface mma, scf::ForOp forOp);

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ void hoistOpsBefore(Operation *refOp,
4848
void hoistOpsBefore(Block *block, Block::iterator it,
4949
const llvm::SetVector<Operation *> &toHoist);
5050

51+
//===----------------------------------------------------------------------===//
52+
// Sinking Utilities
53+
//===----------------------------------------------------------------------===//
54+
55+
// Sink a value redefinition into a block, provided that the block is dominated
56+
// by `in` and postdominated by `out`.
57+
Value sinkValueRedefinition(RewriterBase &rewriter, Value in, Value out,
58+
Block *block);
59+
5160
//===----------------------------------------------------------------------===//
5261
// Loop Pipelining Utilities
5362
//===----------------------------------------------------------------------===//

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,6 @@ SetVector<Value> getNestedOperands(Operation *op);
243243
// Erase the given loop carried values from the loop, where `loop` is replaced
244244
// with a new loop.
245245
void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
246-
247-
// Return true if two value sets may refer to the same allocation.
248-
bool mayAliasAllocations(const DenseSet<Value> &lhs,
249-
const DenseSet<Value> &rhs);
250246
} // namespace mlir
251247

252248
namespace mlir::triton {

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
4040
"void",
4141
"setPredicate",
4242
(ins "::mlir::Value":$pred)>,
43+
InterfaceMethod<"Get the memory dependencies of the accumulator.",
44+
"::mlir::Value",
45+
"getAccDep">,
46+
InterfaceMethod<"Get the mutable memory dependencies of the accumulator.",
47+
"::mlir::MutableOperandRange",
48+
"getAccDepMutable">,
49+
InterfaceMethod<"Get the produced write dependency of the accumulator.",
50+
"::mlir::Value",
51+
"getToken">,
4352
];
4453
}
4554
#endif // TRITON_NVIDIAGPU_OP_INTERFACES

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
417417
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
418418
DeclareOpInterfaceMethods<DotOpInterface>,
419419
DeclareOpInterfaceMethods<MMAv5OpInterface>,
420-
SameVariadicOperandSize
420+
AttrSizedOperandSegments
421421
]> {
422422
let summary = "block level op mapping to tensorcore gen5 mma";
423423

@@ -427,29 +427,36 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
427427
If there is a barrier the result will be safe to read after a barrier wait.
428428
If $two_ctas is set the op will execute a matmul across two contiguous CTAs, it will read the data distributed across the two CTAs.
429429
and syncronize both CTAs if the op is synchronous.
430+
431+
This operation takes and produces an optional token to indicate TMEM read
432+
and write on its accumulator operand. When the tokens are present, they can
433+
be used to check aliasing and modref on the accumulator memory.
430434
}];
431435

432436
let arguments = (ins
433437
TTG_MemDescType:$a,
434438
TTG_MemDescType:$b,
435439
TTG_MemDescType:$d,
440+
Optional<TTG_AsyncToken>:$acc_dep,
436441
I1:$useD,
437442
I1:$pred,
438443
Variadic<TTG_MemDescType>:$barriers,
439444
Variadic<I1>:$barrier_preds,
440445
OptionalAttr<UnitAttr>:$two_ctas
441446
);
447+
let results = (outs Optional<TTG_AsyncToken>:$token);
442448

443449
let builders = [
444-
OpBuilder<(ins
445-
"Value":$a, "Value":$b, "Value":$d, "Value":$useD, "Value":$pred,
446-
CArg<"bool", "false">:$two_ctas, CArg<"ValueRange", "{}">:$barriers,
450+
OpBuilder<(ins "Type":$token,
451+
"Value":$a, "Value":$b, "Value":$d, "Value":$acc_dep, "Value":$useD,
452+
"Value":$pred, CArg<"bool", "false">:$two_ctas,
453+
CArg<"ValueRange", "{}">:$barriers,
447454
CArg<"ValueRange", "{}">:$barrier_preds)>
448455
];
449456

450457
let assemblyFormat = [{
451-
$a`,` $b`,` $d`,` $useD`,` $pred
452-
`` custom<BarriersAndPreds>($barriers, $barrier_preds)
458+
$a `,` $b `,` $d `` custom<Token>($acc_dep, type($token)) `,` $useD`,`
459+
$pred `` custom<BarriersAndPreds>($barriers, $barrier_preds)
453460
attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,`
454461
qualified(type($d)) (`,` qualified(type($barriers))^)?
455462
}];
@@ -459,20 +466,25 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
459466
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
460467
DeclareOpInterfaceMethods<DotOpInterface, ["verifyDims", "verifyOutputDims"]>,
461468
DeclareOpInterfaceMethods<MMAv5OpInterface>,
462-
SameVariadicOperandSize
469+
AttrSizedOperandSegments
463470
]> {
464471
let summary = "block level op mapping to tensorcore gen5 mma";
465472

466473
let description = [{
467474
$d += matrix_multiply(scale($lhs, $lhs_scale), scale(rlhs, $rhs_scale))
468475
If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
469476
If there is a barrier the result will be safe to read after a barrier wait.
477+
478+
This operation takes and produces an optional token to indicate TMEM read
479+
and write on its accumulator operand. When the tokens are present, they can
480+
be used to check aliasing and modref on the accumulator memory.
470481
}];
471482

472483
let arguments = (ins
473484
TTG_MemDescType:$a,
474485
TTG_MemDescType:$b,
475486
TTG_MemDescType:$d,
487+
Optional<TTG_AsyncToken>:$acc_dep,
476488
TTG_MemDescType:$a_scale,
477489
TTG_MemDescType:$b_scale,
478490
TT_ScaleDotElemTypeAttr:$a_type,
@@ -482,6 +494,8 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
482494
Variadic<TTG_MemDescType>:$barriers,
483495
Variadic<I1>:$barrier_preds
484496
);
497+
let results = (outs Optional<TTG_AsyncToken>:$token);
498+
485499
let extraClassDeclaration = [{
486500
int64_t getBlockM();
487501
int64_t getBlockN();
@@ -491,19 +505,19 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
491505
let builders = [
492506
// Namespaces need to be prefixed so ODS prefers our
493507
// custom builder signature over the default-generated one.
494-
OpBuilder<(ins
508+
OpBuilder<(ins "::mlir::Type":$token,
495509
"::mlir::Value":$a, "::mlir::Value":$b, "::mlir::Value":$d,
496-
"::mlir::Value":$a_scale, "::mlir::Value":$b_scale,
497-
"::mlir::triton::ScaleDotElemType":$a_type,
510+
"::mlir::Value":$acc_dep, "::mlir::Value":$a_scale,
511+
"::mlir::Value":$b_scale, "::mlir::triton::ScaleDotElemType":$a_type,
498512
"::mlir::triton::ScaleDotElemType":$b_type,
499513
"::mlir::Value":$useD, "::mlir::Value":$pred,
500514
CArg<"::mlir::ValueRange", "{}">:$barriers,
501515
CArg<"::mlir::ValueRange", "{}">:$barrier_preds)>
502516
];
503517

504518
let assemblyFormat = [{
505-
$a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred
506-
`lhs` `=` $a_type `rhs` `=` $b_type
519+
$a `,` $b `,` $d `` custom<Token>($acc_dep, type($token)) `,` $a_scale `,`
520+
$b_scale `,` $useD `,` $pred `lhs` `=` $a_type `rhs` `=` $b_type
507521
`` custom<BarriersAndPreds>($barriers, $barrier_preds)
508522
attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,`
509523
qualified(type($d)) `,` qualified(type($a_scale)) `,`
@@ -517,27 +531,55 @@ def TTNG_TMEMLoadOp : TTNG_Op<"tmem_load"> {
517531
let description = [{
518532
This is similar to ttg.local_load except the result layout is restricted to only few possibility.
519533
Therefore we cannot combine this op with any convert layout like local_load.
534+
535+
This operation takes and produces an optional token to indicate TMEM read
536+
on its source operand. When the tokens are present, they can
537+
be used to check aliasing and modref on the TMEM buffer.
538+
}];
539+
let arguments = (ins
540+
Arg<TTG_MemDescType, "", [MemRead<TensorMemory>]>:$src,
541+
Optional<TTG_AsyncToken>:$dep
542+
);
543+
let results = (outs
544+
TT_Tensor:$result,
545+
Optional<TTG_AsyncToken>:$token
546+
);
547+
548+
let assemblyFormat = [{
549+
$src `` custom<Token>($dep, type($token))
550+
attr-dict `:` qualified(type($src)) `->` type($result)
520551
}];
521-
let arguments = (ins Arg<TTG_MemDescType, "", [MemRead<TensorMemory>]>:$src);
522552

523-
let assemblyFormat = [{$src attr-dict `:` qualified(type($src)) `->` type($result)}];
524-
let results = (outs TT_Tensor:$result);
525553
let hasVerifier = 1;
554+
555+
let extraClassDeclaration = [{
556+
RankedTensorType getType() { return getResult().getType(); }
557+
operator TypedValue<RankedTensorType>() { return getResult(); }
558+
}];
526559
}
527560

528561
def TTNG_TMEMStoreOp : TTNG_Op<"tmem_store"> {
529562
let summary = "Store a distributed tensor into a buffer in tensor memory";
530563

531564
let description = [{
532-
This is similar to ttg.local_local except the source layout is restricted to only few possibility.
565+
This is similar to ttg.local_store except the source layout is restricted to only few possibility.
566+
567+
This operation takes and produces an optional token to indicate TMEM write
568+
on its source operand. When the tokens are present, they can
569+
be used to check aliasing and modref on the TMEM buffer.
533570
}];
534571
let arguments = (ins
535572
Arg<TTG_MemDescType, "", [MemWrite<TensorMemory>]>:$dst,
573+
Optional<TTG_AsyncToken>:$dep,
536574
TT_Tensor:$src,
537575
I1:$pred
538576
);
577+
let results = (outs Optional<TTG_AsyncToken>:$token);
539578

540-
let assemblyFormat = [{$src `,` $dst `,` $pred attr-dict `:` type($src) `->` qualified(type($dst))}];
579+
let assemblyFormat = [{
580+
$src `,` $dst `` custom<Token>($dep, type($token)) `,` $pred
581+
attr-dict `:` type($src) `->` qualified(type($dst))
582+
}];
541583
let hasVerifier = 1;
542584
}
543585

@@ -551,13 +593,21 @@ def TTNG_TMEMAllocOp : TTNG_Op<"tmem_alloc", [DeclareOpInterfaceMethods<MemoryEf
551593
Explicitly deallocating a buffer is optional; see local_dealloc.
552594
}];
553595
let arguments = (ins Optional<TT_Tensor>:$src);
596+
let results = (outs
597+
TTG_MemDescType:$result,
598+
Optional<TTG_AsyncToken>:$token
599+
);
554600

555601
let assemblyFormat = [{
556602
($src^)? attr-dict `:` functional-type(operands, results)
557603
}];
558604

559-
let results = (outs TTG_MemDescType:$result);
560605
let hasVerifier = 1;
606+
607+
let extraClassDeclaration = [{
608+
triton::gpu::MemDescType getType() { return getResult().getType(); }
609+
operator TypedValue<triton::gpu::MemDescType>() { return getResult(); }
610+
}];
561611
}
562612

563613
def TTNG_TMEMSubSliceOp : TTNG_Op<"tmem_subslice", [Pure]> {

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ std::unique_ptr<Pass> createTritonNvidiaGPUMMALoweringPass();
5858

5959
std::unique_ptr<Pass> createTritonNvidiaGPUPromoteLHSToTMemPass();
6060

61+
std::unique_ptr<Pass> createTritonNvidiaGPURemoveTMEMTokensPass();
62+
6163
std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeDescriptorEncodingPass();
6264

6365
std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeTMemSubtilingPass();

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,13 @@ def TritonNvidiaGPUOptimizeTMemSubtilingPass : Pass<"triton-nvidia-optimize-tmem
142142
"mlir::triton::TritonDialect"];
143143
}
144144

145+
def TritonNvidiaGPURemoveTMEMTokensPass : Pass<"triton-nvidia-gpu-remove-tmem-tokens", "mlir::ModuleOp"> {
146+
let summary = "remove TMEM tokens";
147+
148+
let description = [{
149+
The `triton-nvidia-gpu-remove-tmem-tokens` pass removes TMEM memory
150+
dependency tokens from the IR, after they are no longer needed.
151+
}];
152+
}
153+
145154
#endif

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,21 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
4646
for (int baseVersion : versionsSupported) {
4747
if (supportMMA(op, baseVersion))
4848
return baseVersion;
49-
if (baseVersion == 3)
50-
op.emitRemark() << "Warning: can't use MMA V3 for the dot op";
49+
if (baseVersion == 3) {
50+
auto remark = op.emitRemark()
51+
<< "MMA version 3 acceleration not applied due to "
52+
"unsupported shapes or data types.";
53+
remark.attachNote() << "Target compute capability (" << computeCapability
54+
<< ") supports MMA v3.";
55+
}
56+
57+
if (baseVersion == 5) {
58+
auto remark = op.emitRemark()
59+
<< "MMA version 5 acceleration not applied due to "
60+
"unsupported shapes or data types.";
61+
remark.attachNote() << "Target compute capability (" << computeCapability
62+
<< ") supports MMA v5.";
63+
}
5164
}
5265
return 0;
5366
}
@@ -544,15 +557,17 @@ class BlockedToMMAv5 : public mlir::OpRewritePattern<DotOp> {
544557
newDistributedEncoding);
545558
Value cvtAcc =
546559
rewriter.create<ConvertLayoutOp>(loc, newAccType, dotOp.getOperand(2));
560+
auto tokType = rewriter.getType<AsyncTokenType>();
547561
auto acc = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
548-
loc, accMemDescType, cvtAcc);
562+
loc, accMemDescType, tokType, cvtAcc);
549563
auto vTrue = rewriter.create<arith::ConstantIntOp>(dotOp.getLoc(), 1, 1);
550564
auto mma = rewriter.create<triton::nvidia_gpu::TCGen5MMAOp>(
551-
loc, a, b, acc, /*useD=*/vTrue, /*pred=*/vTrue);
565+
loc, tokType, a, b, acc, acc.getToken(), /*useD=*/vTrue,
566+
/*pred=*/vTrue);
552567
mma.setTwoCtas(useTwoCTAs);
553568

554-
auto ld =
555-
rewriter.create<triton::nvidia_gpu::TMEMLoadOp>(loc, newAccType, acc);
569+
auto ld = rewriter.create<triton::nvidia_gpu::TMEMLoadOp>(
570+
loc, newAccType, tokType, acc, /*dep=*/mma.getToken());
556571
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(dotOp, oldRetType, ld);
557572
return success();
558573
}
@@ -697,8 +712,9 @@ class ScaledBlockedToMMAv5
697712
newDistributedEncoding);
698713
Value cvtAcc =
699714
rewriter.create<ConvertLayoutOp>(loc, newAccType, dotOp.getOperand(2));
715+
auto tokType = rewriter.getType<AsyncTokenType>();
700716
auto acc = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
701-
loc, accMemDescType, cvtAcc);
717+
loc, accMemDescType, tokType, cvtAcc);
702718

703719
RankedTensorType oldScaleAType = dotOp.getAScale().getType();
704720
RankedTensorType oldScaleBType = dotOp.getBScale().getType();
@@ -728,17 +744,22 @@ class ScaledBlockedToMMAv5
728744
rewriter.create<ConvertLayoutOp>(loc, newScaleAType, lhsScale);
729745
Value newScaleB =
730746
rewriter.create<ConvertLayoutOp>(loc, newScaleBType, rhsScale);
731-
Value scaleA = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
732-
loc, scaleAType, newScaleA);
733-
Value scaleB = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
734-
loc, scaleBType, newScaleB);
747+
748+
// We don't need to track memory dependencies for the scale operands since
749+
// they are not pipelined.
750+
auto scaleA = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
751+
loc, scaleAType, /*token=*/Type(), newScaleA);
752+
auto scaleB = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
753+
loc, scaleBType, /*token=*/Type(), newScaleB);
754+
735755
auto vTrue = rewriter.create<arith::ConstantIntOp>(dotOp.getLoc(), 1, 1);
736-
rewriter.create<triton::nvidia_gpu::TCGen5MMAScaledOp>(
737-
loc, a, b, acc, scaleA, scaleB, dotOp.getAElemType(),
738-
dotOp.getBElemType(), /*useD=*/vTrue, /*pred=*/vTrue);
756+
auto mmaOp = rewriter.create<triton::nvidia_gpu::TCGen5MMAScaledOp>(
757+
loc, tokType, a, b, acc.getResult(), acc.getToken(), scaleA.getResult(),
758+
scaleB.getResult(), dotOp.getAElemType(), dotOp.getBElemType(),
759+
/*useD=*/vTrue, /*pred=*/vTrue);
739760

740-
auto ld =
741-
rewriter.create<triton::nvidia_gpu::TMEMLoadOp>(loc, newAccType, acc);
761+
auto ld = rewriter.create<triton::nvidia_gpu::TMEMLoadOp>(
762+
loc, newAccType, tokType, acc, mmaOp.getToken());
742763
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(dotOp, oldRetType, ld);
743764
return success();
744765
}

0 commit comments

Comments
 (0)