Skip to content

Commit 292c08f

Browse files
[DOCUMENTS] Update the DPAS encoding documents. (#2746)
Update the DPAS encoding documents based on the OCL interface requirements.

---------

Co-authored-by: Whitney Tsang <[email protected]>
1 parent af8b01d commit 292c08f

File tree

8 files changed

+176
-55
lines changed

8 files changed

+176
-55
lines changed

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
427427
dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(),
428428
intel::DpasEncodingAttr::getOpsPerChannel(elemType),
429429
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
430-
dpasEncoding.getSubGroupSize());
430+
product<unsigned>(dpasEncoding.getThreadsPerWarp()));
431431
newVEncoding = DotOperandEncodingAttr::get(
432432
ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
433433
} else {

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 156 additions & 40 deletions
Large diffs are not rendered by default.

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
134134
SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
135135
size_t rank = getWarpsPerCTA().size();
136136
SmallVector<unsigned> res(rank, 1);
137-
unsigned threadsPerWarp = getSubGroupSize();
137+
unsigned threadsPerWarp = getThreadsPerWarp__();
138138
SmallVector<unsigned> shapeC = getDPASInstShapeC();
139139
unsigned elemsNum = product<unsigned>(shapeC);
140140
unsigned elemsPerThread = elemsNum / threadsPerWarp;
@@ -263,7 +263,7 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
263263
ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, OpIdx opIdx) const {
264264
SmallVector<int64_t> shapePerCTA = getShapePerCTA(*this, shape);
265265
SmallVector<int64_t> rep = getDPASRepetitions(shapePerCTA, opIdx);
266-
unsigned threadsPerWar = getSubGroupSize();
266+
unsigned threadsPerWar = getThreadsPerWarp__();
267267
size_t rank = shape.size();
268268

269269
switch (opIdx) {
@@ -302,7 +302,7 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
302302
size_t rank = getWarpsPerCTA().size();
303303
SmallVector<unsigned> res(rank, 1);
304304
unsigned executionSize = getExecutionSize();
305-
unsigned subGroupSize = getSubGroupSize();
305+
unsigned subGroupSize = getThreadsPerWarp__();
306306
if (subGroupSize < executionSize) {
307307
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
308308
"smaller than the execution size");
@@ -321,7 +321,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const {
321321
switch (opIdx) {
322322
case OpIdx::OperandA: {
323323
SmallVector<unsigned> shapeA = getDPASInstShapeA();
324-
unsigned subGroupSize = getSubGroupSize();
324+
unsigned subGroupSize = getThreadsPerWarp__();
325325
unsigned opsPerChannel = getOpsPerChannel();
326326

327327
// pack the value to i16 for scalar bit width <=16.
@@ -339,7 +339,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const {
339339
} break;
340340
case OpIdx::OperandB: {
341341
SmallVector<unsigned> shapeB = getShapeB();
342-
unsigned subGroupSize = getSubGroupSize();
342+
unsigned subGroupSize = getThreadsPerWarp__();
343343
unsigned executionSize = getExecutionSize();
344344
if (subGroupSize < executionSize) {
345345
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
@@ -359,7 +359,7 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() const {
359359
assert(rank == 2 || rank == 3);
360360
SmallVector<unsigned> contigPerThread(rank, 1);
361361

362-
unsigned threadsPerWarp = getSubGroupSize();
362+
unsigned threadsPerWarp = getThreadsPerWarp__();
363363
SmallVector<unsigned> instShapeC = getDPASInstShapeC();
364364
// The software vectorization vectorized the value as C array: int a[N] ->
365365
// int a[N][threadsPerWarp]
@@ -494,7 +494,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
494494
<< "systolicDepth = " << getSystolicDepth() << ", "
495495
<< "executionSize = " << getExecutionSize() << ", "
496496
<< "opsPerChan = " << getOpsPerChannel() << ", "
497-
<< "threadsPerWarp = " << getSubGroupSize() << ", "
497+
<< "threadsPerWarp = " << getThreadsPerWarp__() << ", "
498498
<< "warpsPerCTA = [" << llvm::ArrayRef<unsigned>(warpsPerCTA) << "], "
499499
<< "repCluster = [" << repCluster << "], " << "A = [" << rA << "], "
500500
<< "B = [" << rB << "], " << "C = [" << rC << "]" << "}>";

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ struct ConvertLayoutOpConversion
334334
size_t totalElems = elems.size();
335335
auto numElemsPerOperand =
336336
product<unsigned>(dpasLayout.getDPASInstShapeC()) /
337-
dpasLayout.getSubGroupSize();
337+
product<unsigned>(dpasLayout.getThreadsPerWarp());
338338
Type elemTy =
339339
this->getTypeConverter()->convertType(srcType.getElementType());
340340
VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class DotOpDPASConversionHelper {
3737
Type i16Ty = type::i16Ty(ctx);
3838
Type s32Ty = IntegerType::get(ctx, 32, IntegerType::Signed);
3939

40-
unsigned threadsPerWarp = layout.getSubGroupSize();
40+
unsigned threadsPerWarp = product<unsigned>(layout.getThreadsPerWarp());
4141
unsigned opsPerChannel = layout.getOpsPerChannel();
4242
SmallVector<unsigned> shapeC = layout.getDPASInstShapeC();
4343
unsigned elemNumC = product<unsigned>(shapeC) / threadsPerWarp;

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
120120
sizePerThreads[rank - 2] / repCluster[rank - 2],
121121
sizePerThreads[rank - 1] / repCluster[rank - 1]};
122122

123-
unsigned rowsPerElem = dpasLayout.getSubGroupSize() / instShapeC[1];
123+
unsigned rowsPerElem =
124+
product<unsigned>(dpasLayout.getThreadsPerWarp()) / instShapeC[1];
124125
unsigned colsPerElem = 1;
125126

126127
unsigned repNumber = product<unsigned>(repCluster);

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,8 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
304304
auto opEncoding = ttg::intel::DpasEncodingAttr::get(
305305
ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
306306
dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
307-
dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize());
307+
dpasEnc.getRepCluster(),
308+
product<unsigned>(dpasEnc.getThreadsPerWarp()));
308309

309310
auto newOpEncoding = ttg::DotOperandEncodingAttr::get(
310311
ctx, unsigned(opIdx), opEncoding, opEncoding.getOpsPerChannel());
@@ -362,7 +363,8 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
362363
auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get(
363364
ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
364365
dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
365-
dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize());
366+
dpasEnc.getRepCluster(),
367+
product<unsigned>(dpasEnc.getThreadsPerWarp()));
366368
auto retDotOpEncoding =
367369
ttg::DotOperandEncodingAttr::get(ctx, unsigned(opIdx), retDpasEncoding,
368370
retDpasEncoding.getOpsPerChannel());

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
12+
#include "triton/Dialect/Triton/IR/Utility.h"
1213

1314
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1415

@@ -237,11 +238,12 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
237238

238239
// We want to transpose matrices of N*threads_per_warpxthreads_per_warp
239240
// shape.
241+
unsigned threadsPerWarp = product<unsigned>(encoding.getThreadsPerWarp());
240242
if ( // X axis condition
241-
encoding.getExecutionSize() != encoding.getSubGroupSize() ||
243+
encoding.getExecutionSize() != threadsPerWarp ||
242244
// Y axis conditions
243245
(encoding.getRepeatCount() * encoding.getRepCluster()[0]) %
244-
encoding.getSubGroupSize() !=
246+
threadsPerWarp !=
245247
0)
246248
return failure();
247249

0 commit comments

Comments
 (0)