Skip to content

Commit 548d1eb

Browse files
Merge OpenAI Triton commit f804bbc (#5050)
This PR change the Triton base from 31baa6d to f804bbc (Sep 1). Pass rate: 98.74%->98.73%
2 parents 542a9a6 + a0d613f commit 548d1eb

File tree

34 files changed

+1404
-1934
lines changed

34 files changed

+1404
-1934
lines changed

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,7 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
4949
/*retType=*/"::mlir::Value",
5050
/*methodName=*/"getB",
5151
/*args=*/(ins)>,
52-
InterfaceMethod<
53-
/*desc=*/"Get the output tensor",
54-
/*retType=*/"::mlir::Value",
55-
/*methodName=*/"getD",
56-
/*args=*/(ins)>,
57-
InterfaceMethod<
52+
InterfaceMethod<
5853
/*desc=*/"Verify the dimensions of the A and B DotOp operands.",
5954
/*retType=*/"bool",
6055
/*methodName=*/"verifyDims",
@@ -69,7 +64,6 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
6964
auto aTy = cast<ShapedType>($_op.getA().getType());
7065
auto bTy = cast<ShapedType>($_op.getB().getType());
7166
auto cTy = cast<ShapedType>($_op->getOperand(2).getType());
72-
auto dTy = cast<ShapedType>($_op.getD().getType());
7367
auto aShape = aTy.getShape();
7468
auto bShape = bTy.getShape();
7569
auto cShape = cTy.getShape();

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ class AMDRotatingSharedEncodingAttr;
1919
class AMDMfmaEncodingAttr;
2020
class TensorOrMemDesc;
2121
class MemDescType;
22-
class CTALayoutAttr;
2322

2423
// - BlockedEncodingAttrs have the following input dimensions.
2524
//
@@ -127,13 +126,6 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
127126
ArrayRef<unsigned> tilesPerWarp,
128127
ArrayRef<unsigned> warpsPerCTA);
129128

130-
LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, int dotOperandIdx,
131-
ArrayRef<int64_t> dotOperandShape,
132-
ArrayRef<unsigned> tilesPerWarp,
133-
ArrayRef<unsigned> warpsPerCTA,
134-
unsigned instrM, unsigned instrN,
135-
CTALayoutAttr ctaLayoutAttr);
136-
137129
// Create LinearLayout for nvidia mma tile.
138130
LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
139131
unsigned kWidth, ArrayRef<unsigned> order,

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,83 +1407,6 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
14071407
return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth);
14081408
}
14091409

1410-
// Warp-level block scaling (sm_120, m16n8k32)
1411-
// Reference: NVIDIA PTX ISA "Warp-level block scaling"
1412-
// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
1413-
//
1414-
// Semantics:
1415-
// D = (A * SF_A) * (B * SF_B) + C
1416-
// scale_vec::1X -> SF_A shape Mx1 (per-row), SF_B shape 1xN (per-col)
1417-
//
1418-
// Providers (within each warp quad of 4 lanes):
1419-
// - A scales are provided by a lane-pair selected by thread-id-a ∈ {0,1}
1420-
// (0 => lanes {0,1}, 1 => lanes {2,3} in the quad).
1421-
// - B scales are provided by a single lane selected by thread-id-b ∈
1422-
// {0,1,2,3}.
1423-
//
1424-
// Byte selectors (which subfield of the 32-bit metadata is used):
1425-
// - 1X: 1 byte => byte-id ∈ {0,1,2,3}
1426-
//
1427-
// Implementation notes:
1428-
// - We support only scale_vec::1X for now.
1429-
// - We choose a fixed provider for A (thread-id-a = 0) and B (thread-id-b =
1430-
// 0)
1431-
// - In this implementation, each lane in a quad has the same scale factor.
1432-
LinearLayout getSM120DotScaledScaleLayout(
1433-
MLIRContext *ctx, int dotOperandIdx, ArrayRef<int64_t> dotOperandShape,
1434-
ArrayRef<unsigned> tilesPerWarp, ArrayRef<unsigned> warpsPerCTA,
1435-
unsigned mmaInstrM, unsigned mmaInstrN, CTALayoutAttr ctaLayoutAttr) {
1436-
unsigned rank = dotOperandShape.size();
1437-
auto outDims = standardOutDimNames(ctx, rank);
1438-
1439-
StringAttr kRegister = StringAttr::get(ctx, "register");
1440-
StringAttr kLane = StringAttr::get(ctx, "lane");
1441-
StringAttr kWarp = StringAttr::get(ctx, "warp");
1442-
1443-
const unsigned mIndex = 0;
1444-
const unsigned nIndex = 1;
1445-
const int instrM = mmaInstrM;
1446-
const int instrN = mmaInstrN;
1447-
const int kSize = dotOperandShape[1];
1448-
const int mWarps = warpsPerCTA[mIndex];
1449-
const int nWarps = warpsPerCTA[nIndex];
1450-
const int totalWarps = mWarps * nWarps;
1451-
const unsigned mRep_warp = tilesPerWarp[mIndex];
1452-
const unsigned nRep_warp = tilesPerWarp[nIndex];
1453-
const unsigned kRep = std::min<unsigned>(kSize, 2);
1454-
1455-
std::vector<std::vector<int32_t>> registerBase;
1456-
std::vector<std::vector<int32_t>> laneBase;
1457-
std::vector<std::vector<int32_t>> warpBase;
1458-
if (dotOperandIdx == 0) { // per-row A-scale
1459-
laneBase = {{0, 8}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
1460-
for (int offset = instrM * mWarps; offset < instrM * mWarps * mRep_warp;
1461-
offset <<= 1)
1462-
registerBase.push_back({0, offset});
1463-
for (int w = mWarps; w < totalWarps; w <<= 1)
1464-
warpBase.push_back({0, 0});
1465-
for (int offset = instrM; offset < instrM * mWarps; offset <<= 1)
1466-
warpBase.push_back({0, offset});
1467-
} else { // per-col B-scale
1468-
laneBase = {{0, 0}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
1469-
if (nRep_warp > 1)
1470-
registerBase.push_back({0, nWarps * instrN});
1471-
for (int k = 1; k < kRep; k += 1)
1472-
registerBase.push_back({1 << (k - 1), 0});
1473-
for (int offset = instrN; offset < instrN * nWarps; offset <<= 1)
1474-
warpBase.push_back({0, offset});
1475-
for (int w = nWarps; w < totalWarps; w <<= 1)
1476-
warpBase.push_back({0, 0});
1477-
}
1478-
1479-
const unsigned kIdx = (dotOperandShape[0] == 1) ? 0 : 1;
1480-
const unsigned mnIdx = 1 - kIdx;
1481-
LinearLayout ctaLayout(
1482-
{{kRegister, registerBase}, {kLane, laneBase}, {kWarp, warpBase}},
1483-
{outDims[kIdx], outDims[mnIdx]});
1484-
return combineCtaCgaWithShape(ctaLayout, ctaLayoutAttr, dotOperandShape);
1485-
}
1486-
14871410
LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
14881411
ArrayRef<int64_t> dotOperandShape,
14891412
unsigned mfmaMDim,

0 commit comments

Comments
 (0)