Skip to content
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(),
intel::DpasEncodingAttr::getOpsPerChannel(elemType),
dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
dpasEncoding.getSubGroupSize());
product<unsigned>(dpasEncoding.getThreadsPerWarp()));
newVEncoding = DotOperandEncodingAttr::get(
ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
} else {
Expand Down

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
size_t rank = getWarpsPerCTA().size();
SmallVector<unsigned> res(rank, 1);
unsigned threadsPerWarp = getSubGroupSize();
unsigned threadsPerWarp = getThreadsPerWarp__();
SmallVector<unsigned> shapeC = getDPASInstShapeC();
unsigned elemsNum = product<unsigned>(shapeC);
unsigned elemsPerThread = elemsNum / threadsPerWarp;
Expand Down Expand Up @@ -263,7 +263,7 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, OpIdx opIdx) const {
SmallVector<int64_t> shapePerCTA = getShapePerCTA(*this, shape);
SmallVector<int64_t> rep = getDPASRepetitions(shapePerCTA, opIdx);
unsigned threadsPerWar = getSubGroupSize();
unsigned threadsPerWar = getThreadsPerWarp__();
size_t rank = shape.size();

switch (opIdx) {
Expand Down Expand Up @@ -302,7 +302,7 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
size_t rank = getWarpsPerCTA().size();
SmallVector<unsigned> res(rank, 1);
unsigned executionSize = getExecutionSize();
unsigned subGroupSize = getSubGroupSize();
unsigned subGroupSize = getThreadsPerWarp__();
if (subGroupSize < executionSize) {
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
"smaller than the execution size");
Expand All @@ -321,7 +321,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const {
switch (opIdx) {
case OpIdx::OperandA: {
SmallVector<unsigned> shapeA = getDPASInstShapeA();
unsigned subGroupSize = getSubGroupSize();
unsigned subGroupSize = getThreadsPerWarp__();
unsigned opsPerChannel = getOpsPerChannel();

// pack the value to i16 for scalar bit width <=16.
Expand All @@ -339,7 +339,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const {
} break;
case OpIdx::OperandB: {
SmallVector<unsigned> shapeB = getShapeB();
unsigned subGroupSize = getSubGroupSize();
unsigned subGroupSize = getThreadsPerWarp__();
unsigned executionSize = getExecutionSize();
if (subGroupSize < executionSize) {
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
Expand All @@ -359,7 +359,7 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() const {
assert(rank == 2 || rank == 3);
SmallVector<unsigned> contigPerThread(rank, 1);

unsigned threadsPerWarp = getSubGroupSize();
unsigned threadsPerWarp = getThreadsPerWarp__();
SmallVector<unsigned> instShapeC = getDPASInstShapeC();
// Software vectorization represents the value as a C array: int a[N] ->
// int a[N][threadsPerWarp]
Expand Down Expand Up @@ -494,7 +494,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
<< "systolicDepth = " << getSystolicDepth() << ", "
<< "executionSize = " << getExecutionSize() << ", "
<< "opsPerChan = " << getOpsPerChannel() << ", "
<< "threadsPerWarp = " << getSubGroupSize() << ", "
<< "threadsPerWarp = " << getThreadsPerWarp__() << ", "
<< "warpsPerCTA = [" << llvm::ArrayRef<unsigned>(warpsPerCTA) << "], "
<< "repCluster = [" << repCluster << "], " << "A = [" << rA << "], "
<< "B = [" << rB << "], " << "C = [" << rC << "]" << "}>";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ struct ConvertLayoutOpConversion
size_t totalElems = elems.size();
auto numElemsPerOperand =
product<unsigned>(dpasLayout.getDPASInstShapeC()) /
dpasLayout.getSubGroupSize();
product<unsigned>(dpasLayout.getThreadsPerWarp());
Type elemTy =
this->getTypeConverter()->convertType(srcType.getElementType());
VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DotOpDPASConversionHelper {
Type i16Ty = type::i16Ty(ctx);
Type s32Ty = IntegerType::get(ctx, 32, IntegerType::Signed);

unsigned threadsPerWarp = layout.getSubGroupSize();
unsigned threadsPerWarp = product<unsigned>(layout.getThreadsPerWarp());
unsigned opsPerChannel = layout.getOpsPerChannel();
SmallVector<unsigned> shapeC = layout.getDPASInstShapeC();
unsigned elemNumC = product<unsigned>(shapeC) / threadsPerWarp;
Expand Down
3 changes: 2 additions & 1 deletion third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
sizePerThreads[rank - 2] / repCluster[rank - 2],
sizePerThreads[rank - 1] / repCluster[rank - 1]};

unsigned rowsPerElem = dpasLayout.getSubGroupSize() / instShapeC[1];
unsigned rowsPerElem =
product<unsigned>(dpasLayout.getThreadsPerWarp()) / instShapeC[1];
unsigned colsPerElem = 1;

unsigned repNumber = product<unsigned>(repCluster);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
auto opEncoding = ttg::intel::DpasEncodingAttr::get(
ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize());
dpasEnc.getRepCluster(),
product<unsigned>(dpasEnc.getThreadsPerWarp()));

auto newOpEncoding = ttg::DotOperandEncodingAttr::get(
ctx, unsigned(opIdx), opEncoding, opEncoding.getOpsPerChannel());
Expand Down Expand Up @@ -362,7 +363,8 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get(
ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize());
dpasEnc.getRepCluster(),
product<unsigned>(dpasEnc.getThreadsPerWarp()));
auto retDotOpEncoding =
ttg::DotOperandEncodingAttr::get(ctx, unsigned(opIdx), retDpasEncoding,
retDpasEncoding.getOpsPerChannel());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
#include "triton/Dialect/Triton/IR/Utility.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand Down Expand Up @@ -237,11 +238,12 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {

// We want to transpose matrices of shape
// (N * threads_per_warp) x threads_per_warp.
unsigned threadsPerWarp = product<unsigned>(encoding.getThreadsPerWarp());
if ( // X axis condition
encoding.getExecutionSize() != encoding.getSubGroupSize() ||
encoding.getExecutionSize() != threadsPerWarp ||
// Y axis conditions
(encoding.getRepeatCount() * encoding.getRepCluster()[0]) %
encoding.getSubGroupSize() !=
threadsPerWarp !=
0)
return failure();

Expand Down