
Commit 306051e

Merge OpenAI Triton commit 413b521 (#3902)
This PR changes the Triton base from 6f0ae97 to 413b521 (Apr 8). Pass rate: 90.77% -> 88.4%. Please do not squash and merge this PR.
2 parents 55a2172 + 3910d9e commit 306051e

File tree

41 files changed (+805 / -834 lines)


include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 0 additions & 6 deletions
@@ -53,9 +53,6 @@ createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
                           TypeRange types, ValueRange args);
 } // namespace mlir::LLVM
 
-// Is v an integer or floating-point scalar constant equal to 0?
-bool isConstantZero(Value v);
-
 namespace mlir::triton {
 
 struct TritonLLVMOpBuilder {
@@ -348,9 +345,6 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
 namespace LLVM {
 using namespace mlir::triton;
 
-// Is v an integer or floating-point scalar constant equal to 0?
-bool isConstantZero(Value v);
-
 class SharedMemoryObject {
 public:
   SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> offsets)

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 17 additions & 1 deletion
@@ -53,7 +53,23 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
         /*desc=*/"Verify the dimensions of the A and B DotOp operands.",
         /*retType=*/"bool",
         /*methodName=*/"verifyDims",
-        /*args=*/(ins)>
+        /*args=*/(ins)>,
+    InterfaceMethod<
+        /*desc=*/"Verify the dimensions of the DotOp output.",
+        /*retType=*/"bool",
+        /*methodName=*/"verifyOutputDims",
+        /*args=*/(ins),
+        /*methodBody=*/[{}],
+        /*defaultImpl=*/ [{
+          auto aTy = cast<ShapedType>($_op.getA().getType());
+          auto bTy = cast<ShapedType>($_op.getB().getType());
+          auto cTy = cast<ShapedType>($_op->getOperand(2).getType());
+          auto aShape = aTy.getShape();
+          auto bShape = bTy.getShape();
+          auto cShape = cTy.getShape();
+          return cShape[cShape.size() - 2] == aShape[aShape.size() - 2] &&
+                 cShape[cShape.size() - 1] == bShape[aShape.size() - 1];
+        }]>
   ];
 
   let verify = [{ return ::mlir::triton::impl::verifyDotOpInterface($_op); }];
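The default implementation above is a plain shape check on the last two dimensions: the result's M must match A's M and the result's N must match B's N (the TableGen body indexes bShape with aShape.size() - 1, which is equivalent when A and B have the same rank). A minimal standalone sketch of that check, using a hypothetical helper name that is not part of the commit:

#include <cstdint>
#include <vector>

// Hypothetical illustration of what the default verifyOutputDims() computes:
// C[..., M, N] is valid iff M matches A[..., M, K] and N matches B[..., K, N].
static bool outputDimsMatch(const std::vector<int64_t> &aShape,
                            const std::vector<int64_t> &bShape,
                            const std::vector<int64_t> &cShape) {
  int64_t m = aShape[aShape.size() - 2];
  int64_t n = bShape[bShape.size() - 1];
  return cShape[cShape.size() - 2] == m && cShape[cShape.size() - 1] == n;
}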

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 4 additions & 2 deletions
@@ -676,7 +676,7 @@ def TT_DotOp : TT_Op<"dot", [Pure,
 //
 def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
                            AttrSizedOperandSegments,
-                           DeclareOpInterfaceMethods<DotOpInterface>,
+                           DeclareOpInterfaceMethods<DotOpInterface, ["verifyDims", "verifyOutputDims"]>,
                            TypesMatchWith<"result's type matches accumulator's type",
                                           "d", "c", "$_self">]> {
   let summary = "dot_scaled";
@@ -697,7 +697,9 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
     Optional<RankedTensorOf<[TT_Float, I8]>>:$b_scale,
     TT_ScaleDotElemTypeAttr:$a_elem_type,
     TT_ScaleDotElemTypeAttr:$b_elem_type,
-    BoolAttr:$fastMath
+    BoolAttr:$fastMath,
+    DefaultValuedAttr<BoolAttr, "true">:$lhs_k_pack,
+    DefaultValuedAttr<BoolAttr, "true">:$rhs_k_pack
   );
 
   let results = (outs TT_FloatTensor:$d);

include/triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h

Lines changed: 8 additions & 55 deletions
@@ -13,50 +13,6 @@ class ForOp;
 } // namespace scf
 namespace triton::nvidia_gpu {
 
-//===----------------------------------------------------------------------===//
-// MMAInfo
-//===----------------------------------------------------------------------===//
-
-// This struct contains analysis information about an MMAv5 operation inside a
-// loop used for pipelining MMA ops.
-struct MMAInfo {
-  // This struct contains information about when the MMA's accumulator is
-  // overridden in the loop, if it is at all.
-  struct AccOverridePoint {
-    // The operation which overrides the accumulator.
-    Operation *op;
-    // The condition on which the accumulator is reset.
-    Value condition = nullptr;
-    // The initial value of the accumulator and the value after a reset.
-    Value initValue = nullptr;
-    // The number of loop iterations ago the accumulator was reset.
-    int distance = 0;
-    // Whether the accumulator is reset via setting the `useAcc` flag to false
-    // or by clearing the accumulator tensor value.
-    bool isFlag = false;
-  };
-
-  // The TMEM allocation of the accumuator, which directly precedes the dot op.
-  TMEMAllocOp accAlloc;
-  // The TMEM load of the accumulator value out of TMEM, which directly follows
-  // the dot op.
-  TMEMLoadOp accLoad;
-  // The override point of the accumulator value, if it is overriden in the
-  // loop. E.g. this is typically present for persistent kernels.
-  std::optional<AccOverridePoint> accDef;
-  // If the accumulator is used in future iterations of the loop, this is the
-  // iter arg number.
-  std::optional<int> yieldArgNo;
-  // Whether the accumulator needs to be multibuffered.
-  bool accIsMultiBuffered;
-
-  Value phase = nullptr;
-  Value barrierIdx = nullptr;
-  Value accInsertIdx = nullptr;
-  Value accExtractIdx = nullptr;
-  Value barrierAlloc = nullptr;
-};
-
 //===----------------------------------------------------------------------===//
 // MMA Pipeline Analysis
 //===----------------------------------------------------------------------===//
@@ -66,12 +22,14 @@ struct MMAInfo {
 // be in the same region as the MMA operation.
 std::optional<std::pair<TMEMAllocOp, TMEMLoadOp>>
 getTMemAllocAndLoad(MMAv5OpInterface mmaOp);
-// Get immediate users of the accumulator within the current loop iteration.
-SmallVector<Operation *> getDirectAccUses(TMEMLoadOp accDef);
-// Analyze an MMA op inside a loop to determine information about how it can be
-// pipelined. Returns `std::nullopt` if it cannot be pipelined.
-std::optional<MMAInfo> getMMAInfo(scf::ForOp forOp, MMAv5OpInterface mmaOp,
-                                  DominanceInfo &domInfo);
+// Given an MMAv5 operation in a loop, determine if its accumulator can be
+// multibuffered.
+bool isAccMultibufferingPossible(MMAv5OpInterface mma, scf::ForOp forOp);
+// Only pipeline the loops where the MMA happens before the tmem_load, or is in
+// the same stage as the tmem_load. Lowering does not support the case where the
+// MMA is in a different stage as the tmem_load and happens after it.
+bool mmav5DominatesTmemLoads(
+    scf::ForOp forOp, function_ref<bool(MMAv5OpInterface)> isMmaPipelineable);
 
 //===----------------------------------------------------------------------===//
 // MMA Pipeline Rewriters
@@ -82,11 +40,6 @@ std::optional<MMAInfo> getMMAInfo(scf::ForOp forOp, MMAv5OpInterface mmaOp,
 TMEMAllocOp createTMemAlloc(OpBuilder &builder, TMEMAllocOp oldTMemAllocOp,
                             bool multiBufferred, int numStages);
 
-// Create a store op of the initial value of the accumulator into the
-// potentially multi-buffered accumulator.
-void createInitStore(OpBuilder &builder, TMEMAllocOp allocOp, Value initVal,
-                     bool multiBufferred);
-
 // Return true if operands of the MMA operation are/are going to be pipelined
 // and multibuffered, enabling the MMA operation to be pipelined.
 bool mmaHasPipelineableOperands(

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 6 additions & 0 deletions
@@ -11,6 +11,7 @@
 
 namespace mlir {
 class DominanceInfo;
+class PostDominanceInfo;
 
 namespace triton {
 class ModuleAxisInfoAnalysis;
@@ -222,6 +223,11 @@ getMMAsWithMultiBufferredOperands(scf::ForOp forOp,
 // regions. The result op is not necessarily one of the ops in the list.
 Operation *findNearestCommonDominator(ArrayRef<Operation *> ops,
                                       DominanceInfo &domInfo);
+// Given a list of ops, find the naerest common postdominator of all ops or
+// return null if one could not be found. The ops are allowed to be in different
+// regions. The result op is not necessarily one of the ops in the list.
+Operation *findNearestCommonPostDominator(ArrayRef<Operation *> ops,
+                                          PostDominanceInfo &postDomInfo);
 
 /// Visit the operands of `op` and the operands of any nested ops defined
 /// outside of `op`.
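The new findNearestCommonPostDominator helper is only declared here. As a rough intuition, here is a simplified sketch under the assumption that all ops live in one block, which is not the implementation added by this commit: in that restricted case the nearest common postdominator is simply the op that appears last.

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Operation.h"

// Simplified sketch, not the commit's implementation: handles only ops that
// all live in the same block, where the nearest common postdominator is the
// textually last op. The real helper also covers ops in different regions
// via PostDominanceInfo.
static mlir::Operation *
nearestCommonPostDomInBlock(llvm::ArrayRef<mlir::Operation *> ops) {
  if (ops.empty())
    return nullptr;
  mlir::Operation *last = ops.front();
  for (mlir::Operation *op : ops.drop_front()) {
    if (op->getBlock() != last->getBlock())
      return nullptr; // cross-block case is out of scope for this sketch
    if (last->isBeforeInBlock(op))
      last = op;
  }
  return last;
}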

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 6 additions & 2 deletions
@@ -404,7 +404,7 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryE
   let assemblyFormat = "$a`,` $b`,` $d`,` $useD`,` $pred (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
 }
 
-def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>, DeclareOpInterfaceMethods<MMAv5OpInterface>]> {
+def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface, ["verifyDims", "verifyOutputDims"]>, DeclareOpInterfaceMethods<MMAv5OpInterface>]> {
   let summary = "block level op mapping to tensorcore gen5 mma";
 
   let description = [{
@@ -423,7 +423,11 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMe
     I1:$useD,
     I1:$pred,
     Optional<TTG_MemDescType>:$barrier);
-
+  let extraClassDeclaration = [{
+    int64_t getBlockM();
+    int64_t getBlockN();
+    int64_t getBlockK();
+  }];
   // TODO: improve printing format.
   let assemblyFormat = "$a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred `lhs` `=` $a_type `rhs` `=` $b_type (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
 }

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 0 additions & 12 deletions
@@ -667,18 +667,6 @@ createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
   return op;
 }
 
-bool isConstantZero(Value v) {
-  if (auto constantOp = v.getDefiningOp<arith::ConstantOp>()) {
-    if (auto attr = dyn_cast<IntegerAttr>(constantOp.getValue())) {
-      return attr.getValue().isZero();
-    }
-    if (auto attr = dyn_cast<FloatAttr>(constantOp.getValue())) {
-      return attr.getValue().isZero();
-    }
-  }
-  return false;
-}
-
 Value getStructFromSharedMemoryObject(Location loc,
                                       const SharedMemoryObject &smemObj,
                                       RewriterBase &rewriter) {
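The deleted helper checked whether a Value is a scalar integer or floating-point constant equal to zero. Out-of-tree code that still needs such a check could get roughly the same behavior from MLIR's built-in matchers; this is a hedged suggestion, not something this commit introduces:

#include "mlir/IR/Matchers.h"
#include "mlir/IR/Value.h"

// Roughly equivalent to the removed isConstantZero(): m_Zero() matches an
// integer constant 0, m_AnyZeroFloat() matches a +0.0 or -0.0 float constant.
static bool isZeroConstant(mlir::Value v) {
  return mlir::matchPattern(v, mlir::m_Zero()) ||
         mlir::matchPattern(v, mlir::m_AnyZeroFloat());
}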

lib/Dialect/Triton/IR/OpInterfaces.cpp

Lines changed: 1 addition & 2 deletions
@@ -64,8 +64,7 @@ LogicalResult verifyDotOpInterface(Operation *op) {
         "operand to be equal to the first dimension of "
         "the result");
   // Check the output shape
-  if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] ||
-      cShape[cShape.size() - 1] != bShape[aShape.size() - 1])
+  if (!dotOp.verifyOutputDims())
     return dotOp->emitOpError(
         "expected the output shape to be the concatenation of the last "
         "dimension of the first operand and the last dimension of the "

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 29 additions & 4 deletions
@@ -323,14 +323,39 @@ bool DotScaledOp::verifyDims() {
 
   auto aKdim = aShape[aShape.size() - 1];
   auto bKdim = bShape[aShape.size() - 2];
-  if (this->getAElemType() == ScaleDotElemType::E2M1)
-    aKdim *= 2;
-  if (this->getBElemType() == ScaleDotElemType::E2M1)
-    bKdim *= 2;
+  if (this->getAElemType() == ScaleDotElemType::E2M1) {
+    if (this->getLhsKPack())
+      aKdim *= 2;
+  }
+  if (this->getBElemType() == ScaleDotElemType::E2M1) {
+    if (this->getRhsKPack())
+      bKdim *= 2;
+  }
 
   return aKdim == bKdim;
 }
 
+bool DotScaledOp::verifyOutputDims() {
+  auto cShape = this->getC().getType().getShape();
+  auto oMdim = cShape[cShape.size() - 2];
+  auto oNdim = cShape[cShape.size() - 1];
+  auto aShape = this->getA().getType().getShape();
+  auto bShape = this->getB().getType().getShape();
+  auto adim = aShape[aShape.size() - 2];
+  auto bdim = bShape[bShape.size() - 1];
+  if (this->getAElemType() == ScaleDotElemType::E2M1) {
+    if (!this->getLhsKPack())
+      adim *= 2;
+  }
+  if (this->getBElemType() == ScaleDotElemType::E2M1) {
+    if (!this->getRhsKPack())
+      bdim *= 2;
+  }
+  if (adim != oMdim || bdim != oNdim)
+    return false;
+  return true;
+}
+
 //-- MakeRangeOp --
 OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) {
   // make_range(start, start + 1) -> constant(start)
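To make the packing arithmetic concrete: E2M1 (fp4) operands store two elements per i8, so whichever axis carries the packing counts double when logical shapes are compared. A small hypothetical example, with shapes invented for illustration rather than taken from the commit:

#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical dot_scaled with fp4 (E2M1) operands stored in i8 tensors.
  // A stored as 128x32 with lhs_k_pack = true: the K axis doubles, logical K = 64.
  int64_t aM = 128, aKStored = 32;
  int64_t aK = aKStored * 2;
  // B stored as 32x64 with rhs_k_pack = true: logical K = 64 as well.
  int64_t bKStored = 32, bN = 64;
  int64_t bK = bKStored * 2;
  assert(aK == bK); // the condition verifyDims() enforces
  // verifyOutputDims(): C must then be 128x64. Had lhs_k_pack been false, the
  // packing would sit on A's M axis instead, and that dimension would double.
  int64_t cM = 128, cN = 64;
  assert(cM == aM && cN == bN);
  return 0;
}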

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 23 additions & 22 deletions
@@ -149,6 +149,7 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
 static Value
 getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
                           bool allowTranspose, bool isMMAv5Fp4Padded = false,
+                          bool forceTranspose = false,
                           Operation *op = nullptr /*only for diagnostic*/) {
   OpBuilder::InsertionGuard g(rewriter);
   Value arg = v;
@@ -167,6 +168,8 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
     } else {
       newOrder = {1, 0};
     }
+    if (forceTranspose)
+      std::swap(newOrder[0], newOrder[1]);
   }
 
   if (newOrder != order && op) {
@@ -648,49 +651,47 @@ class ScaledBlockedToMMAv5
 
     bool IsAMixedPrecFp4 = false;
     bool IsBMixedPrecFp4 = false;
+    bool isAFP4 = dotOp.getAElemType() == ScaleDotElemType::E2M1;
+    bool isBFP4 = dotOp.getBElemType() == ScaleDotElemType::E2M1;
 
     if (dotOp.getAElemType() != dotOp.getBElemType()) {
-      if (dotOp.getAElemType() == ScaleDotElemType::E2M1)
+      if (isAFP4)
         IsAMixedPrecFp4 = true;
-      else if (dotOp.getBElemType() == ScaleDotElemType::E2M1)
+      else if (isBFP4)
        IsBMixedPrecFp4 = true;
     }
-
+    // If we use txgen05.mma.kind.mxf864 we need to padd the fp4 operands:
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-packing-formats-mxf8f6f4-smem
+    bool isMMAv5Fp4PaddedLhs = IsAMixedPrecFp4 || !dotOp.getLhsKPack();
+    bool isMMAv5Fp4PaddedRhs = IsBMixedPrecFp4 || !dotOp.getRhsKPack();
     // For mixed-precision fp4 operands, set allowTranspose = false, to force
     // the packed axis, K, to be contiguous in SMEM
     a = getSharedMemoryMMAOperand(a, rewriter, 0,
-                                  /*allowTranspose=*/!IsAMixedPrecFp4,
-                                  IsAMixedPrecFp4, dotOp);
+                                  /*allowTranspose=*/!isAFP4,
+                                  /*isMMAv5Fp4Padded=*/isMMAv5Fp4PaddedLhs,
+                                  /*forceTranspose=*/!dotOp.getLhsKPack(),
+                                  dotOp);
     b = getSharedMemoryMMAOperand(b, rewriter, 1,
-                                  /*allowTranspose=*/!IsBMixedPrecFp4,
-                                  IsBMixedPrecFp4, dotOp);
+                                  /*allowTranspose=*/!isBFP4,
+                                  /*isMMAv5Fp4Padded=*/isMMAv5Fp4PaddedRhs,
+                                  /*forceTranspose=*/!dotOp.getRhsKPack(),
+                                  dotOp);
 
     MLIRContext *context = dotOp->getContext();
     unsigned m = 128;
     unsigned n = retShapePerCTA[1] >= 256 ? 256 : retShapePerCTA[1];
-    unsigned k = 32;
-    // If both operands are E2M1, target the FP4 tensor core implicitly.
-    // This may result in a downstream compile-time error if the scaled TC
-    // descriptor requires options that are unavailable to the .kind=mxf4 mma.
-    // This is likely preferable over a silent runtime performance degradation
-    // from running f4xf4 via .kind=mxf8f6f4
-    if (dotOp.getAElemType() == ScaleDotElemType::E2M1 &&
-        dotOp.getBElemType() == ScaleDotElemType::E2M1) {
-      k = 64;
-    }
-    SmallVector<unsigned> instrShape = {m, n, k};
+
     ArrayRef<unsigned> CTASplitNum = CTALayout.getCTASplitNum();
     Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get(
-        context, instrShape[0], instrShape[1], /*unpacked=*/true,
-        CTASplitNum[0], CTASplitNum[1]);
+        context, m, n, /*unpacked=*/true, CTASplitNum[0], CTASplitNum[1]);
     Attribute tensorMemorySpace =
         triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
     Type accMemDescType = triton::gpu::MemDescType::get(
         oldRetType.getShape(), oldRetType.getElementType(), accEncoding,
         tensorMemorySpace,
         /*mutableMemory=*/true);
-    Attribute newDistributedEncoding = nvidia_gpu::getTmemCompatibleLayout(
-        instrShape[0], instrShape[1], oldRetType, numWarps);
+    Attribute newDistributedEncoding =
+        nvidia_gpu::getTmemCompatibleLayout(m, n, oldRetType, numWarps);
     auto newAccType = RankedTensorType::get(oldRetType.getShape(),
                                             oldRetType.getElementType(),
                                             newDistributedEncoding);