
Commit 2d934e7

Merge commit '31baa6d284b6a52ba464bfc260bc9395e66a4ac8'
2 parents: e510e90 + 31baa6d

37 files changed: +1137 / -234 lines changed

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 7 additions & 1 deletion
@@ -49,7 +49,12 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
         /*retType=*/"::mlir::Value",
         /*methodName=*/"getB",
         /*args=*/(ins)>,
-      InterfaceMethod<
+      InterfaceMethod<
+        /*desc=*/"Get the output tensor",
+        /*retType=*/"::mlir::Value",
+        /*methodName=*/"getD",
+        /*args=*/(ins)>,
+      InterfaceMethod<
         /*desc=*/"Verify the dimensions of the A and B DotOp operands.",
         /*retType=*/"bool",
         /*methodName=*/"verifyDims",
@@ -64,6 +69,7 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
     auto aTy = cast<ShapedType>($_op.getA().getType());
     auto bTy = cast<ShapedType>($_op.getB().getType());
     auto cTy = cast<ShapedType>($_op->getOperand(2).getType());
+    auto dTy = cast<ShapedType>($_op.getD().getType());
     auto aShape = aTy.getShape();
    auto bShape = bTy.getShape();
    auto cShape = cTy.getShape();
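The new getD accessor lets the default verifyDims body fetch the output type alongside A, B, and C, presumably so the dimensions of D can be validated as well. As a rough standalone illustration of that contract (a toy model, not MLIR or Triton code; all names and shapes below are invented), the output of a dot should be MxN when A is MxK and B is KxN:

#include <cstdio>

// Toy model of the extended interface: a, b, d stand in for the tensors
// returned via getA()/getB()/getD(), and verifyDims() checks that D is MxN
// when A is MxK and B is KxN.
struct Shape2D { int rows, cols; };

struct ToyDot {
  Shape2D a, b, d;
  bool verifyDims() const {
    return a.cols == b.rows && d.rows == a.rows && d.cols == b.cols;
  }
};

int main() {
  ToyDot ok{{16, 32}, {32, 8}, {16, 8}};   // D matches A's M and B's N
  ToyDot bad{{16, 32}, {32, 8}, {16, 16}}; // D's N disagrees with B's N
  std::printf("ok: %d  bad: %d\n", ok.verifyDims(), bad.verifyDims());
  return 0;
}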

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,7 @@ class AMDRotatingSharedEncodingAttr;
 class AMDMfmaEncodingAttr;
 class TensorOrMemDesc;
 class MemDescType;
+class CTALayoutAttr;
 
 // - BlockedEncodingAttrs have the following input dimensions.
 //
@@ -126,6 +127,13 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<unsigned> tilesPerWarp,
                                          ArrayRef<unsigned> warpsPerCTA);
 
+LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, int dotOperandIdx,
+                                          ArrayRef<int64_t> dotOperandShape,
+                                          ArrayRef<unsigned> tilesPerWarp,
+                                          ArrayRef<unsigned> warpsPerCTA,
+                                          unsigned instrM, unsigned instrN,
+                                          CTALayoutAttr ctaLayoutAttr);
+
 // Create LinearLayout for nvidia mma tile.
 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                            unsigned kWidth, ArrayRef<unsigned> order,

lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 18 additions & 20 deletions
@@ -25,7 +25,6 @@ static SmallVector<Value> computeWarpLevelHistogram(
   int numBits = llvm::Log2_64(numBins);
   int numBitsLaneId = llvm::Log2_64(numThreadPerWarp);
   unsigned numElementsPerThreads = getTotalElemsPerThread(srcType);
-  unsigned numThreadWithUniqueData = getThreadsPerWarp(srcType)[0];
   // The histogram is distributed across threads, each thread owns `numBins /
   // numThreadPerWarp` bins.
   SmallVector<Value> warpLevelHistogram(numBins / numThreadPerWarp, zero);
@@ -43,10 +42,6 @@ static SmallVector<Value> computeWarpLevelHistogram(
       numThreadPerWarp == 32 ? 0xFFFFFFFF : 0xFFFFFFFFFFFFFFFF;
   Value fullMask = b.int_val(numThreadPerWarp, fullMaskValue);
   Value mask = fullMask;
-  // If not all threads have unique data, mask out the redundant ones.
-  if (numThreadWithUniqueData < numThreadPerWarp) {
-    mask = b.int_val(numThreadPerWarp, (1ULL << numThreadWithUniqueData) - 1);
-  }
   for (int i = 0; i < numBitsLaneId; i++) {
     Value updateMask =
         b.select(b.icmp_ne(b.and_(threadId, b.i32_val(1 << i)), zero),
@@ -96,8 +91,6 @@ static SmallVector<Value> computeCrossWarpHistogram(
     Value threadId, int numWarps) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   SmallVector<Value> histogramValues;
-  unsigned numWarpsWithUniqueData = mlir::triton::gpu::getWarpsPerCTA(
-      srcType.getEncoding(), srcType.getShape())[0];
   Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1));
   // Initialize the shared memory with zeros.
   int64_t numElementPerThread =
@@ -112,19 +105,6 @@ static SmallVector<Value> computeCrossWarpHistogram(
   }
   b.barrier();
   Block *afterAtomics = nullptr;
-  // If some warps have replicated data we need to skip those warps when
-  // accumulating.
-  if (numWarpsWithUniqueData < numWarps) {
-    Block *currentBlock = rewriter.getInsertionBlock();
-    afterAtomics =
-        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
-    Block *atomicBlock = rewriter.createBlock(afterAtomics);
-    rewriter.setInsertionPointToEnd(currentBlock);
-    Value cond = b.icmp_ult(
-        threadId, b.i32_val(numWarpsWithUniqueData * numThreadPerWarp));
-    rewriter.create<LLVM::CondBrOp>(loc, cond, atomicBlock, afterAtomics);
-    rewriter.setInsertionPointToStart(atomicBlock);
-  }
   // Apply atomic add to update the histogram in shared memory.
   for (int i = 0; i < warpLevelHistogram.size(); ++i) {
     Value warpLevelHistogramValue = warpLevelHistogram[i];
@@ -209,6 +189,24 @@ struct HistogramOpConversion
         loc, rewriter, srcType, baseSharedMemPtr, warpLevelHistogram, numBins,
         numThreadsPerWarp, innerDimIndices, threadId, numWarps);
 
+    // Depending on the layout, some threads may have duplicate data. We can
+    // account for this by calculating a "replication factor" and dividing the
+    // results by it to avoid overcounting.
+    auto replicationFactor = numWarps * numThreadsPerWarp;
+    auto threadsPerWarp = getThreadsPerWarp(srcType);
+    auto warpsPerCTA =
+        getWarpsPerCTA(srcType.getEncoding(), srcType.getShape());
+    replicationFactor /= std::accumulate(
+        threadsPerWarp.begin(), threadsPerWarp.end(), 1, std::multiplies<>());
+    replicationFactor /= std::accumulate(warpsPerCTA.begin(), warpsPerCTA.end(),
+                                         1, std::multiplies<>());
+
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    for (auto i = 0; i < histogramValue.size(); ++i) {
+      histogramValue[i] =
+          b.sdiv(histogramValue[i], b.i32_val(replicationFactor));
+    }
+
     Value results = packLLElements(loc, typeConverter, histogramValue, rewriter,
                                    op.getType());
     rewriter.replaceOp(op, results);
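To see why the final division recovers the correct counts, here is a self-contained sketch of the replication-factor arithmetic with assumed values (the configuration below is illustrative, not taken from this commit):

#include <cstdio>

int main() {
  // Assumed configuration: a CTA of 4 warps x 32 threads, but the source
  // layout places unique data on only 1 warp (32 threads); the other warps
  // hold copies, so every element is atomically added 4 times.
  unsigned numWarps = 4, numThreadsPerWarp = 32;
  unsigned uniqueThreadsPerWarp = 32; // product of getThreadsPerWarp(srcType)
  unsigned uniqueWarpsPerCTA = 1;     // product of getWarpsPerCTA(enc, shape)
  unsigned replicationFactor = (numWarps * numThreadsPerWarp) /
                               (uniqueThreadsPerWarp * uniqueWarpsPerCTA);
  // Each histogram bin is then divided by this factor, as in the loop above.
  std::printf("replicationFactor = %u\n", replicationFactor); // prints 4
  return 0;
}

When every thread and warp holds unique data the two products equal numThreadsPerWarp and numWarps, the factor is 1, and the division is a no-op.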

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 0 additions & 6 deletions
@@ -1339,12 +1339,6 @@ AMDWmmaEncodingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
   if (version != 1 && version != 2) {
     return emitError() << "WMMA version must be in the [1, 2] range";
   }
-  // Transposed layout is needed for bypassing LDS between multiple dots.
-  // Version 1 tt.dot results and tt.dot operand layouts are different,
-  // therefore we test and support transposed only for version 2.
-  if (version != 2 && isTransposed) {
-    return emitError() << "Transposed WMMA is supported only for version 2";
-  }
   return success();
 }
 

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 77 additions & 0 deletions
@@ -1407,6 +1407,83 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
   return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth);
 }
 
+// Warp-level block scaling (sm_120, m16n8k32)
+// Reference: NVIDIA PTX ISA "Warp-level block scaling"
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
+//
+// Semantics:
+//   D = (A * SF_A) * (B * SF_B) + C
+//   scale_vec::1X -> SF_A shape Mx1 (per-row), SF_B shape 1xN (per-col)
+//
+// Providers (within each warp quad of 4 lanes):
+//   - A scales are provided by a lane-pair selected by thread-id-a ∈ {0,1}
+//     (0 => lanes {0,1}, 1 => lanes {2,3} in the quad).
+//   - B scales are provided by a single lane selected by thread-id-b ∈
+//     {0,1,2,3}.
+//
+// Byte selectors (which subfield of the 32-bit metadata is used):
+//   - 1X: 1 byte => byte-id ∈ {0,1,2,3}
+//
+// Implementation notes:
+//   - We support only scale_vec::1X for now.
+//   - We choose a fixed provider for A (thread-id-a = 0) and B (thread-id-b =
+//     0)
+//   - In this implementation, each lane in a quad has the same scale factor.
+LinearLayout getSM120DotScaledScaleLayout(
+    MLIRContext *ctx, int dotOperandIdx, ArrayRef<int64_t> dotOperandShape,
+    ArrayRef<unsigned> tilesPerWarp, ArrayRef<unsigned> warpsPerCTA,
+    unsigned mmaInstrM, unsigned mmaInstrN, CTALayoutAttr ctaLayoutAttr) {
+  unsigned rank = dotOperandShape.size();
+  auto outDims = standardOutDimNames(ctx, rank);
+
+  StringAttr kRegister = StringAttr::get(ctx, "register");
+  StringAttr kLane = StringAttr::get(ctx, "lane");
+  StringAttr kWarp = StringAttr::get(ctx, "warp");
+
+  const unsigned mIndex = 0;
+  const unsigned nIndex = 1;
+  const int instrM = mmaInstrM;
+  const int instrN = mmaInstrN;
+  const int kSize = dotOperandShape[1];
+  const int mWarps = warpsPerCTA[mIndex];
+  const int nWarps = warpsPerCTA[nIndex];
+  const int totalWarps = mWarps * nWarps;
+  const unsigned mRep_warp = tilesPerWarp[mIndex];
+  const unsigned nRep_warp = tilesPerWarp[nIndex];
+  const unsigned kRep = std::min<unsigned>(kSize, 2);
+
+  std::vector<std::vector<int32_t>> registerBase;
+  std::vector<std::vector<int32_t>> laneBase;
+  std::vector<std::vector<int32_t>> warpBase;
+  if (dotOperandIdx == 0) { // per-row A-scale
+    laneBase = {{0, 8}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
+    for (int offset = instrM * mWarps; offset < instrM * mWarps * mRep_warp;
+         offset <<= 1)
+      registerBase.push_back({0, offset});
+    for (int w = mWarps; w < totalWarps; w <<= 1)
+      warpBase.push_back({0, 0});
+    for (int offset = instrM; offset < instrM * mWarps; offset <<= 1)
+      warpBase.push_back({0, offset});
+  } else { // per-col B-scale
+    laneBase = {{0, 0}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
+    if (nRep_warp > 1)
+      registerBase.push_back({0, nWarps * instrN});
+    for (int k = 1; k < kRep; k += 1)
+      registerBase.push_back({1 << (k - 1), 0});
+    for (int offset = instrN; offset < instrN * nWarps; offset <<= 1)
+      warpBase.push_back({0, offset});
+    for (int w = nWarps; w < totalWarps; w <<= 1)
+      warpBase.push_back({0, 0});
+  }
+
+  const unsigned kIdx = (dotOperandShape[0] == 1) ? 0 : 1;
+  const unsigned mnIdx = 1 - kIdx;
+  LinearLayout ctaLayout(
+      {{kRegister, registerBase}, {kLane, laneBase}, {kWarp, warpBase}},
+      {outDims[kIdx], outDims[mnIdx]});
+  return combineCtaCgaWithShape(ctaLayout, ctaLayoutAttr, dotOperandShape);
+}
+
 LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<int64_t> dotOperandShape,
                                          unsigned mfmaMDim,
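For intuition about the basis tables above, the following standalone sketch evaluates the per-row A-scale laneBase the way such a layout is typically applied, assuming the convention that each set bit of an input index XORs in its basis vector (this is a reading of the code, not code from the commit). The all-zero basis for lane bit 1 means lanes that differ only in that bit map to the same scale element.

#include <array>
#include <cstdio>
#include <vector>

int main() {
  // {k, m} contribution per lane-id bit 0..4, copied from the A-scale branch
  // of getSM120DotScaledScaleLayout.
  std::vector<std::array<int, 2>> laneBase = {
      {0, 8}, {0, 0}, {0, 1}, {0, 2}, {0, 4}};
  for (int lane = 0; lane < 32; ++lane) {
    int k = 0, m = 0;
    for (int bit = 0; bit < (int)laneBase.size(); ++bit) {
      if (lane & (1 << bit)) { // XOR-combine the basis of every set bit
        k ^= laneBase[bit][0];
        m ^= laneBase[bit][1];
      }
    }
    std::printf("lane %2d -> (k=%d, m=%d)\n", lane, k, m);
  }
  return 0;
}

Because every m contribution is a distinct power of two, the XORs act like sums here: the 32 lanes cover rows 0 through 15, each row held by the two lanes that differ only in bit 1, and the register/warp bases then translate that 16-row pattern across M tiles and warps.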

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 6 additions & 0 deletions
@@ -908,6 +908,12 @@ LogicalResult WarpSpecializeOp::verify() {
                      "cannot be nested inside another `ttg.warp_specialize` op");
   }
 
+  std::optional<int> numWarps = maybeLookupNumWarps(*this);
+  if (numWarps && *numWarps % 4 != 0) {
+    return mlir::emitError(getLoc()) << "warp-specialized kernels requires "
+                                        "num_warps to be a multiple of 4";
+  }
+
   return success();
 }
 
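A quick sanity check of the new rule, written as a standalone sketch that mirrors the added condition rather than the Triton verifier itself: only multiples of 4 pass, and an unknown num_warps is not rejected.

#include <cstdio>
#include <optional>

// Mirrors the added check: reject only when num_warps is known and not a
// multiple of 4 (maybeLookupNumWarps may return no value).
static bool numWarpsOkForWarpSpecialize(std::optional<int> numWarps) {
  return !(numWarps && *numWarps % 4 != 0);
}

int main() {
  for (int n : {1, 2, 4, 6, 8, 12})
    std::printf("num_warps=%d -> %s\n", n,
                numWarpsOkForWarpSpecialize(n) ? "ok" : "error");
  std::printf("num_warps unset -> %s\n",
              numWarpsOkForWarpSpecialize(std::nullopt) ? "ok" : "error");
  return 0;
}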
