
Commit a93c725

Merge OpenAI Triton commit a5b948c (#5556)
This PR changes the Triton base from c33b2d9 to a5b948c (Nov 7). Pass rate: 95.23%
2 parents: dcb8ff0 + 16c1334

29 files changed: +839, -227 lines


include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 4 additions & 2 deletions
@@ -125,8 +125,10 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<unsigned> warpsPerCTA);
 
 LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
-                                         ArrayRef<unsigned> warpsPerCTA,
-                                         ArrayRef<int64_t> dotOperandShape);
+                                         ArrayRef<int64_t> dotOperandShape,
+                                         unsigned wmmaMDim,
+                                         ArrayRef<unsigned> tilesPerWarp,
+                                         ArrayRef<unsigned> warpsPerCTA);
 
 LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx,
                                           ArrayRef<int64_t> shape, int opIdx,

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 45 additions & 2 deletions
@@ -1133,7 +1133,7 @@ Example 4:
 This example demonstrates semantics of tilesPerWarp parameter. The MFMA layout (with tilesPerWarp=[1,1])
 assumes that each warp within a CTA tile computes a single MFMA tile. When the tensor is larger than
 a single CTA tile, these tiles are repeated across the tensor. In this setup, the output tiles computed
-by each wave were strided by the number of warps per CTA tile in both row and column dimensions.
+by each warp were strided by the number of warps per CTA tile in both row and column dimensions.
 
 For instance, with 16 MFMA tiles and warpsPerCTA = [2, 2], the distribution of warps across the MFMA
 tiles looked like:
@@ -1214,11 +1214,12 @@ It is characterized by the following parameters:
   - 2: RDNA4; e.g., gfx1200, gfx1201
   - 3: gfx1250
 - `warpsPerCTA` indicates the warp layout in the block.
+- `tilesPerWarp` indicates the tile layout within a warp. Defaults to the unit tile layout, i.e., a single tile in all dimensions.
 - `instrShape` indicates the shape in the form of (M, N, K) of the matrix
   operation performed by a single WMMA instruction. Defaults to (16, 16, 16).
 - `isTransposed` indicates the layout of the result tensor is transposed.
 
-Example:
+Example 1:
 Suppose we have a tensor with shape [32, 64], `warpsPerCTA` set to [2, 2].
 Matrix elements represent which lane owns the element. Currently only wave32 mode
 is supported.
@@ -1292,20 +1293,59 @@ Row |
 ..  | ...                                 ...
 30  |[14 14 14 14 14 14 14 14 30 ... 30] [14 14 14 ... 30]
 31  |[15 15 15 15 15 15 15 15 31 ... 31] [15 15 15 ... 31]
+
+Example 2:
+This example demonstrates the tilesPerWarp parameter, which shares the same semantics with
+AMDMfmaEncodingAttr.
+
+By default, the WMMA layout assumes that each warp within a CTA tile computes a single WMMA tile.
+When the tensor is larger than a single CTA tile, these tiles are repeated across the tensor.
+In this setup, the output tiles computed by each warp are strided by the number of warps per CTA
+tile in both row and column dimensions.
+
+For instance, with 16 WMMA tiles and warpsPerCTA = [2, 2], the default (tilesPerWarp = [1, 1])
+distribution of warps across the WMMA tiles looks like:
+
+  w0 w1 w0 w1
+  w2 w3 w2 w3
+  w0 w1 w0 w1
+  w2 w3 w2 w3
+
+* Each unit represents a WMMA tile. w* shows which warp occupies that WMMA tile.
+
+The tilesPerWarp parameter allows each warp to compute contiguous WMMA tiles in the row and/or column dimensions.
+Using the same example with tilesPerWarp = [2, 2], the layout becomes:
+
+  w0 w0 w1 w1
+  w0 w0 w1 w1
+  w2 w2 w3 w3
+  w2 w2 w3 w3
 }];
 
 let parameters = (
   ins
   "unsigned": $version,
   "bool":$isTransposed,
   ArrayRefParameter<"unsigned">:$warpsPerCTA,
+  ArrayRefParameter<"unsigned">:$tilesPerWarp,
   "CTALayoutAttr":$CTALayout,
   ArrayRefParameter<"unsigned">:$instrShape
 );
 
 let genVerifyDecl = 1;
 let hasCustomAssemblyFormat = 1;
 
+let builders = [
+  AttrBuilder<(ins "unsigned":$version,
+                   "bool":$isTransposed,
+                   "ArrayRef<unsigned>":$warpsPerCTA,
+                   "CTALayoutAttr":$CTALayout,
+                   "ArrayRef<unsigned>":$instrShape), [{
+    SmallVector<unsigned> tilesPerWarp(warpsPerCTA.size(), 1);
+    return $_get(context, version, isTransposed, warpsPerCTA, tilesPerWarp, CTALayout, instrShape);
+  }]>
+];
+
 let extraClassDeclaration = extraDistributedDeclaration # [{
   SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kDim, int opIdx) const;
   SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
@@ -1314,6 +1354,9 @@ Row |
     return {16, 16, 16};
   }
 
+  // Check if tilesPerWarp is 1 in every dimension.
+  bool hasUnitTilesPerWarp() const;
+
   // Returns a swizzled shared layout matching this WMMA layout for the
   // dot operand at the given |operandIdx| with |operandShape|.
   SwizzledSharedEncodingAttr composeSharedLayoutForOperand(
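For orientation, a minimal sketch of constructing the encoding with and without the new parameter. The values are illustrative, `ctx` and `ctaLayout` are assumed to be in scope, and the second call assumes the auto-generated full builder keeps the parameter order declared above:

```cpp
// Sketch only: illustrative values, not from this commit.
// The 5-parameter builder added above backfills tilesPerWarp with 1s, so
// existing call sites keep their old behavior.
AMDWmmaEncodingAttr wmmaDefault = AMDWmmaEncodingAttr::get(
    ctx, /*version=*/2, /*isTransposed=*/false,
    /*warpsPerCTA=*/{2, 2}, ctaLayout, /*instrShape=*/{16, 16, 16});

// Opting in: the auto-generated full builder takes tilesPerWarp explicitly;
// [2, 2] gives each warp a 2x2 block of contiguous WMMA tiles.
AMDWmmaEncodingAttr wmmaTiled = AMDWmmaEncodingAttr::get(
    ctx, /*version=*/2, /*isTransposed=*/false, /*warpsPerCTA=*/{2, 2},
    /*tilesPerWarp=*/{2, 2}, ctaLayout, /*instrShape=*/{16, 16, 16});
```

Per the custom printer below in Dialect.cpp, tilesPerWarp only appears in the printed attribute when it is non-unit, which keeps existing IR round-trips unchanged.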

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 12 additions & 0 deletions
@@ -26,6 +26,7 @@
 
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -51,6 +52,17 @@ LogicalResult verifyMMAv5Op(Operation *op);
 
 namespace mlir::triton::nvidia_gpu {
 
+constexpr static char AttrTwoCTAsName[] = "ttng.two-ctas";
+
+inline bool getModuleTwoCTAs(ModuleOp mod) {
+  auto attr = mod->getAttrOfType<BoolAttr>(AttrTwoCTAsName);
+  return attr ? attr.getValue() : false;
+}
+
+inline bool getModuleTwoCTAs(Operation *op) {
+  return getModuleTwoCTAs(op->getParentOfType<ModuleOp>());
+}
+
 struct TensorMemory : public SideEffects::Resource::Base<TensorMemory> {
   StringRef getName() final { return "<TensorMemory>"; }
 };
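A short usage sketch for the new helpers. The recording side uses the standard MLIR `setAttr`/`BoolAttr` APIs; only `AttrTwoCTAsName` and `getModuleTwoCTAs` come from this diff:

```cpp
// Record the agreed two-CTA mode on the module (e.g., from a check pass).
void recordTwoCTAs(mlir::ModuleOp mod, bool twoCTAs) {
  mod->setAttr(mlir::triton::nvidia_gpu::AttrTwoCTAsName,
               mlir::BoolAttr::get(mod.getContext(), twoCTAs));
}

// Later lowering code can query it from any op; defaults to false if unset.
bool usesTwoCTAs(mlir::Operation *op) {
  return mlir::triton::nvidia_gpu::getModuleTwoCTAs(op);
}
```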

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 10 additions & 0 deletions
@@ -174,4 +174,14 @@ def TritonNvidiaGPURemoveTMEMTokensPass : Pass<"triton-nvidia-gpu-remove-tmem-to
   }];
 }
 
+def TritonNvidiaGPUCheckMatmulTwoCTAPass : Pass<"triton-nvidia-check-matmul-two-cta", "mlir::ModuleOp"> {
+  let summary = "Verify consistent two_ctas usage across matmuls";
+
+  let description = [{
+    Inspect all matmul operations and ensure they agree on the `two_ctas`
+    setting. Propagate the chosen value to the module so later lowering steps
+    can access it. Compilation fails if mixed configurations are detected.
+  }];
+}
+
 #endif
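The description above implies a straightforward structure for the pass body. A hypothetical sketch; the `isMatmul` predicate and the `"two_ctas"` attribute spelling are assumptions for illustration, not taken from this diff:

```cpp
// Assumed helper identifying the matmul ops this pass inspects.
bool isMatmul(mlir::Operation *op);

mlir::LogicalResult checkMatmulTwoCTA(mlir::ModuleOp mod) {
  std::optional<bool> twoCTAs;
  auto walk = mod.walk([&](mlir::Operation *op) {
    if (!isMatmul(op))
      return mlir::WalkResult::advance();
    bool opTwoCTAs = op->hasAttr("two_ctas"); // assumed attribute spelling
    if (twoCTAs && *twoCTAs != opTwoCTAs) {
      op->emitError("mixed two_ctas configurations across matmuls");
      return mlir::WalkResult::interrupt();
    }
    twoCTAs = opTwoCTAs;
    return mlir::WalkResult::advance();
  });
  if (walk.wasInterrupted())
    return mlir::failure(); // compilation fails on mixed configurations
  // Propagate the chosen value to the module (see ttng.two-ctas above).
  mod->setAttr(mlir::triton::nvidia_gpu::AttrTwoCTAsName,
               mlir::BoolAttr::get(mod.getContext(), twoCTAs.value_or(false)));
  return mlir::success();
}
```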

include/triton/Tools/LinearLayout.h

Lines changed: 9 additions & 0 deletions
@@ -459,6 +459,15 @@ class LinearLayout {
   auto getOutDimSizes() const { return llvm::make_second_range(outDims); }
 
   // Relevant for reshaping
+
+  SmallVector<std::pair<StringAttr, int32_t>> getInDims() const {
+    SmallVector<std::pair<StringAttr, int32_t>> inDims;
+    inDims.reserve(bases.size());
+    for (auto [inDim, inDimBases] : bases) {
+      inDims.push_back({inDim, getInDimSize(inDim)});
+    }
+    return inDims;
+  }
   SmallVector<std::pair<StringAttr, int32_t>> getOutDims() const {
     return to_vector(outDims);
   }
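The new getInDims() mirrors the existing getOutDims(): ordered (dimension name, size) pairs, which is the shape information reshape-style code needs. A small usage sketch, assuming a LinearLayout value `ll` is in scope:

```cpp
// Compute the total input volume of a layout from its (name, size) pairs.
int64_t volume = 1;
for (auto [dim, size] : ll.getInDims())
  volume *= size; // e.g., the "register", "lane", "warp", "block" dim sizes
```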

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 31 additions & 8 deletions
@@ -1287,6 +1287,9 @@ LogicalResult AMDMfmaEncodingAttr::verify(
 //===----------------------------------------------------------------------===//
 // WMMA encoding
 //===----------------------------------------------------------------------===//
+bool AMDWmmaEncodingAttr::hasUnitTilesPerWarp() const {
+  return llvm::all_of(getTilesPerWarp(), [](int x) { return x == 1; });
+}
 
 Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
@@ -1303,6 +1306,7 @@ Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) {
   std::optional<SmallVector<unsigned>> CTAsPerCGA;
   std::optional<SmallVector<unsigned>> CTASplitNum;
   std::optional<SmallVector<unsigned>> CTAOrder;
+  SmallVector<unsigned> tilesPerWarp = {};
   SmallVector<unsigned> instrShape = getDefaultInstrShape();
 
   for (const NamedAttribute &attr : dict) {
@@ -1318,6 +1322,11 @@ Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) {
       if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
         return {};
     }
+    if (attr.getName() == "tilesPerWarp") {
+      if (parseIntArrayAttr(parser, attr, tilesPerWarp, "tilesPerWarp")
+              .failed())
+        return {};
+    }
     if (attr.getName() == "CTAsPerCGA") {
       if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA")
              .failed())
@@ -1346,9 +1355,12 @@ Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) {
   if (!CTALayout.has_value())
     return {};
 
-  return parser.getChecked<AMDWmmaEncodingAttr>(parser.getContext(), version,
-                                                isTransposed, warpsPerCTA,
-                                                *CTALayout, instrShape);
+  if (tilesPerWarp.empty())
+    tilesPerWarp = SmallVector<unsigned>(instrShape.size(), 1);
+
+  return parser.getChecked<AMDWmmaEncodingAttr>(
+      parser.getContext(), version, isTransposed, warpsPerCTA, tilesPerWarp,
+      *CTALayout, instrShape);
 }
 
 void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const {
@@ -1360,6 +1372,10 @@ void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const {
   maybePrintCTALayout(getContext(), printer, getCTALayout(),
                       /*rank=*/getWarpsPerCTA().size());
 
+  auto tilesPerWarp = getTilesPerWarp();
+  if (!hasUnitTilesPerWarp())
+    printer << ", tilesPerWarp = [" << getTilesPerWarp() << "]";
+
   if (getInstrShape() != ArrayRef(getDefaultInstrShape())) {
     printer << ", instrShape = [" << getInstrShape() << "]";
   }
@@ -1369,7 +1385,8 @@ void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const {
 LogicalResult AMDWmmaEncodingAttr::verify(
     function_ref<mlir::InFlightDiagnostic()> emitError, unsigned version,
     bool isTransposed, llvm::ArrayRef<unsigned int> warpsPerCTA,
-    CTALayoutAttr ctaLayout, llvm::ArrayRef<unsigned> instrShape) {
+    llvm::ArrayRef<unsigned int> tilesPerWarp, CTALayoutAttr ctaLayout,
+    llvm::ArrayRef<unsigned> instrShape) {
   if (!(version >= 1 && version <= 3))
     return emitError() << "WMMA version must be in the [1, 3] range";
 
@@ -2176,7 +2193,7 @@ void AMDRotatingSharedEncodingAttr::print(AsmPrinter &printer) const {
 // TODO: there is a lot of common code with MmaEncoding here
 
 bool AMDMfmaEncodingAttr::hasUnitTilesPerWarp() const {
-  return !llvm::any_of(getTilesPerWarp(), [](int x) { return x != 1; });
+  return llvm::all_of(getTilesPerWarp(), [](int x) { return x == 1; });
 }
 
 SmallVector<int64_t>
@@ -2309,6 +2326,8 @@ AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape, int kDim,
 
   assert(operandTileShape.size() == 2);
   auto warpsPerCTA = getWarpsPerCTA();
+  auto tilesPerWarp = getTilesPerWarp();
+
   auto rank = operandShape.size();
   assert(rank == 2 || rank == 3);
   int numRepBatch =
@@ -2317,15 +2336,19 @@ AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape, int kDim,
     return {
         numRepBatch,
         std::max<int64_t>(1, operandShape[rank - 2] /
-                                 (operandTileShape[0] * warpsPerCTA[rank - 2])),
+                                 (operandTileShape[0] * tilesPerWarp[rank - 2] *
+                                  warpsPerCTA[rank - 2])) *
+            tilesPerWarp[rank - 2],
         std::max<int64_t>(1, operandShape[rank - 1] / operandTileShape[1])};
   else {
     assert(opIdx == 1);
     return {
         numRepBatch,
         std::max<int64_t>(1, operandShape[rank - 2] / operandTileShape[0]),
-        std::max<int64_t>(1, operandShape[rank - 1] / (operandTileShape[1] *
-                                                       warpsPerCTA[rank - 1]))};
+        std::max<int64_t>(1, operandShape[rank - 1] /
+                                 (operandTileShape[1] * tilesPerWarp[rank - 1] *
+                                  warpsPerCTA[rank - 1])) *
+            tilesPerWarp[rank - 1]};
   }
 }
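A worked instance of the updated opIdx == 0 repetition arithmetic, with illustrative values not taken from the commit:

```cpp
// operandShape = {128, 64}, operandTileShape = {16, 16},
// warpsPerCTA = {2, 2}, tilesPerWarp = {2, 2}, rank = 2, opIdx = 0:
//   M reps: max(1, 128 / (16 * 2 * 2)) * 2 = 2 * 2 = 4
//           (2 strided repetitions, each covering 2 contiguous tiles)
//   K reps: max(1, 64 / 16) = 4 (K is not split across warps)
// => getRepForOperand returns {1, 4, 4}. The old formula also gave an M
//    count of max(1, 128 / (16 * 2)) = 4, but distributed as 4 strided
//    single tiles rather than 2 groups of 2 contiguous tiles.
```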
