
Commit ac96e7f

Merge OpenAI Triton commit 2ad519c (#5126)
This PR changes the Triton base from 6fa1dd6 to 2ad519c (Sep 10). Pass rate: 98.8%.
2 parents: d6b921e + f3ed6e8 · commit: ac96e7f

File tree: 26 files changed, +1467 −322 lines

README.md

Lines changed: 1 addition & 0 deletions

@@ -232,6 +232,7 @@ See [`python/triton/knobs.py`](python/triton/knobs.py) for the full list of conf
 - `TRITON_F32_DEFAULT` sets the default input precision of `tl.dot` when using 32-bit floats, which can be either `ieee`, `tf32`, or `tf32x3`.
 - `TRITON_FRONT_END_DEBUGGING=1` disables exception wrapping when an error occurs in the compiler frontend, allowing the full stack trace to be seen.
 - `TRITON_DISABLE_LINE_INFO=1` removes all line information from the module.
+- `PTXAS_OPTIONS` passes additional command-line options to the PTX assembler `ptxas` (only on NVIDIA).

 > [!NOTE]
 > Some of these environment variables don't have a knob in `knobs.py` -- those are only relevant to the C++ layer(s), hence they don't exist in the Python layer.
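
As a hedged usage sketch for the new `PTXAS_OPTIONS` variable (setting it from a C++ host process rather than the shell, and the `-v` flag itself, are illustrative assumptions, not part of this commit):

#include <cstdlib>

int main() {
  // Illustrative only: "-v" asks ptxas to report per-kernel resource
  // usage; any valid ptxas command-line option can be forwarded here.
  setenv("PTXAS_OPTIONS", "-v", /*overwrite=*/1);
  // ... Triton compilation launched from this process would pick it up ...
  return 0;
}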

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 7 additions & 1 deletion

@@ -49,7 +49,12 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
         /*retType=*/"::mlir::Value",
         /*methodName=*/"getB",
         /*args=*/(ins)>,
-      InterfaceMethod<
+      InterfaceMethod<
+          /*desc=*/"Get the output tensor",
+          /*retType=*/"::mlir::Value",
+          /*methodName=*/"getD",
+          /*args=*/(ins)>,
+      InterfaceMethod<
         /*desc=*/"Verify the dimensions of the A and B DotOp operands.",
         /*retType=*/"bool",
         /*methodName=*/"verifyDims",
@@ -64,6 +69,7 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
     auto aTy = cast<ShapedType>($_op.getA().getType());
     auto bTy = cast<ShapedType>($_op.getB().getType());
     auto cTy = cast<ShapedType>($_op->getOperand(2).getType());
+    auto dTy = cast<ShapedType>($_op.getD().getType());
     auto aShape = aTy.getShape();
     auto bShape = bTy.getShape();
     auto cShape = cTy.getShape();
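
The new `getD` accessor exposes the dot op's result tensor alongside the existing `getA`/`getB` methods. A minimal sketch of client code, assuming an op implementing `DotOpInterface` in the `mlir::triton` namespace (the surrounding pass logic is illustrative, not from this commit):

// Illustrative only: query the output type through the interface.
if (auto dotOp = llvm::dyn_cast<mlir::triton::DotOpInterface>(op)) {
  auto dTy = llvm::cast<mlir::ShapedType>(dotOp.getD().getType());
  auto cTy = llvm::cast<mlir::ShapedType>(op->getOperand(2).getType());
  // The result D is expected to match the accumulator C's shape.
  assert(dTy.getShape() == cTy.getShape());
}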

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 7 additions & 0 deletions

@@ -135,6 +135,13 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<unsigned> tilesPerWarp,
                                          ArrayRef<unsigned> warpsPerCTA);

+LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, int dotOperandIdx,
+                                          ArrayRef<int64_t> dotOperandShape,
+                                          ArrayRef<unsigned> tilesPerWarp,
+                                          ArrayRef<unsigned> warpsPerCTA,
+                                          unsigned instrM, unsigned instrN,
+                                          CTALayoutAttr ctaLayoutAttr);
+
 // Create LinearLayout for nvidia mma tile.
 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                            unsigned kWidth, ArrayRef<unsigned> order,

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 4 additions & 1 deletion

@@ -40,7 +40,10 @@ bool isPureScalarOp(Operation *op);
 bool getDominatingValueSetOpsToHoist(
     DominanceInfo &domInfo, Operation *refOp, ArrayRef<Value> valueSet,
     llvm::SetVector<Operation *> &toHoist,
-    function_ref<bool(Operation *)> canHoist = isPureScalarOp);
+    function_ref<bool(Operation *)> canHoist = isPureScalarOp,
+    function_ref<bool(BlockArgument)> canUseArg = [](BlockArgument) {
+      return false;
+    });

 // Hoist the given set of operations above the reference operation.
 void hoistOpsBefore(Operation *refOp,
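
The new `canUseArg` callback lets a caller whitelist block arguments that hoisted ops may depend on; the default rejects every argument, preserving the previous behavior. A hedged sketch of a call site (the enclosing loop, the value set, and the induction-variable policy are assumptions made for illustration):

// Illustrative only: additionally allow hoisted ops to reference the
// enclosing scf.for's induction variable.
llvm::SetVector<mlir::Operation *> toHoist;
bool ok = getDominatingValueSetOpsToHoist(
    domInfo, refOp, valueSet, toHoist, isPureScalarOp,
    [&](mlir::BlockArgument arg) { return arg == forOp.getInductionVar(); });
// If `ok`, `toHoist` can then be moved above `refOp` via hoistOpsBefore
// (declared above; full signature not shown in this diff).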

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 77 additions & 0 deletions

@@ -1405,6 +1405,83 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
   return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth);
 }

+// Warp-level block scaling (sm_120, m16n8k32)
+// Reference: NVIDIA PTX ISA "Warp-level block scaling"
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
+//
+// Semantics:
+//   D = (A * SF_A) * (B * SF_B) + C
+//   scale_vec::1X -> SF_A shape Mx1 (per-row), SF_B shape 1xN (per-col)
+//
+// Providers (within each warp quad of 4 lanes):
+//   - A scales are provided by a lane-pair selected by thread-id-a ∈ {0,1}
+//     (0 => lanes {0,1}, 1 => lanes {2,3} in the quad).
+//   - B scales are provided by a single lane selected by thread-id-b ∈
+//     {0,1,2,3}.
+//
+// Byte selectors (which subfield of the 32-bit metadata is used):
+//   - 1X: 1 byte => byte-id ∈ {0,1,2,3}
+//
+// Implementation notes:
+//   - We support only scale_vec::1X for now.
+//   - We choose a fixed provider for A (thread-id-a = 0) and B
+//     (thread-id-b = 0).
+//   - In this implementation, each lane in a quad has the same scale factor.
+LinearLayout getSM120DotScaledScaleLayout(
+    MLIRContext *ctx, int dotOperandIdx, ArrayRef<int64_t> dotOperandShape,
+    ArrayRef<unsigned> tilesPerWarp, ArrayRef<unsigned> warpsPerCTA,
+    unsigned mmaInstrM, unsigned mmaInstrN, CTALayoutAttr ctaLayoutAttr) {
+  unsigned rank = dotOperandShape.size();
+  auto outDims = standardOutDimNames(ctx, rank);
+
+  StringAttr kRegister = StringAttr::get(ctx, "register");
+  StringAttr kLane = StringAttr::get(ctx, "lane");
+  StringAttr kWarp = StringAttr::get(ctx, "warp");
+
+  const unsigned mIndex = 0;
+  const unsigned nIndex = 1;
+  const int instrM = mmaInstrM;
+  const int instrN = mmaInstrN;
+  const int kSize = dotOperandShape[1];
+  const int mWarps = warpsPerCTA[mIndex];
+  const int nWarps = warpsPerCTA[nIndex];
+  const int totalWarps = mWarps * nWarps;
+  const unsigned mRep_warp = tilesPerWarp[mIndex];
+  const unsigned nRep_warp = tilesPerWarp[nIndex];
+  const unsigned kRep = std::min<unsigned>(kSize, 2);
+
+  std::vector<std::vector<int32_t>> registerBase;
+  std::vector<std::vector<int32_t>> laneBase;
+  std::vector<std::vector<int32_t>> warpBase;
+  if (dotOperandIdx == 0) { // per-row A-scale
+    laneBase = {{0, 8}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
+    for (int offset = instrM * mWarps; offset < instrM * mWarps * mRep_warp;
+         offset <<= 1)
+      registerBase.push_back({0, offset});
+    for (int w = mWarps; w < totalWarps; w <<= 1)
+      warpBase.push_back({0, 0});
+    for (int offset = instrM; offset < instrM * mWarps; offset <<= 1)
+      warpBase.push_back({0, offset});
+  } else { // per-col B-scale
+    laneBase = {{0, 0}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
+    if (nRep_warp > 1)
+      registerBase.push_back({0, nWarps * instrN});
+    for (int k = 1; k < kRep; k += 1)
+      registerBase.push_back({1 << (k - 1), 0});
+    for (int offset = instrN; offset < instrN * nWarps; offset <<= 1)
+      warpBase.push_back({0, offset});
+    for (int w = nWarps; w < totalWarps; w <<= 1)
+      warpBase.push_back({0, 0});
+  }
+
+  const unsigned kIdx = (dotOperandShape[0] == 1) ? 0 : 1;
+  const unsigned mnIdx = 1 - kIdx;
+  LinearLayout ctaLayout(
+      {{kRegister, registerBase}, {kLane, laneBase}, {kWarp, warpBase}},
+      {outDims[kIdx], outDims[mnIdx]});
+  return combineCtaCgaWithShape(ctaLayout, ctaLayoutAttr, dotOperandShape);
+}
+
 LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<int64_t> dotOperandShape,
                                          unsigned mfmaMDim,
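
For orientation, a hedged example of constructing an A-scale layout with the new helper (every concrete value below is an assumption made for the sketch, not something this commit pins down):

// Illustrative only: A-scale (dotOperandIdx == 0) for a 128x128x128 dot
// using the sm_120 m16n8k32 instruction, a 2x2 warp grid, 1 tile per warp.
// With scale_vec::1X, A carries one scale per row per 32-wide K slice,
// so the assumed scale-operand shape is M x (K / 32) = 128 x 4.
SmallVector<int64_t> scaleShape = {128, 4};
SmallVector<unsigned> tilesPerWarp = {1, 1};
SmallVector<unsigned> warpsPerCTA = {2, 2};
LinearLayout aScaleLayout = getSM120DotScaledScaleLayout(
    ctx, /*dotOperandIdx=*/0, scaleShape, tilesPerWarp, warpsPerCTA,
    /*mmaInstrM=*/16, /*mmaInstrN=*/8, ctaLayoutAttr);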
