4 changes: 1 addition & 3 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -116,9 +116,7 @@ SmallVector<unsigned> getCTAOrder(Attribute layout);
* (3) In the implementation of emitIndices, ShapePerCTATile will
* be replicated or wrapped to fit ShapePerCTA.
*/
SmallVector<unsigned>
getShapePerCTATile(Attribute layout,
ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
SmallVector<unsigned> getShapePerCTATile(Attribute layout);

SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
ArrayRef<int64_t> shape);
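Note (3) in the doc comment above describes how emitIndices reuses a CTA tile: the tile is replicated across a tensor larger than the tile, or wrapped when the tensor is smaller. A minimal sketch of that bookkeeping, with hypothetical shapes rather than anything taken from this PR:

```cpp
#include <cstdio>

int main() {
  // Hypothetical shapes: the per-CTA tile of a layout vs. the per-CTA
  // shape of the tensor it is applied to.
  unsigned shapePerCTATile[2] = {16, 32};
  long shapePerCTA[2] = {8, 64};
  for (int d = 0; d < 2; ++d) {
    // Ceil-division: how many tile repetitions cover the tensor. Here dim 0
    // wraps (tile larger than tensor, 1 repetition) and dim 1 replicates
    // the tile twice.
    long reps = (shapePerCTA[d] + shapePerCTATile[d] - 1) / shapePerCTATile[d];
    printf("dim %d: %ld repetition(s)\n", d, reps);
  }
  return 0;
}
```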
16 changes: 0 additions & 16 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -502,11 +502,6 @@ We call each individual tile "rep".
"SmallVector<unsigned>",
"getCTASplitNum">,

InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA",
"SmallVector<unsigned>",
"getShapePerCTATile",
(ins "ArrayRef<int64_t>":$tensorShape)>,

InterfaceMethod<"Gets the number of contiguous elements per thread.",
"SmallVector<unsigned>",
"getContigPerThread">,
@@ -565,7 +560,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
SmallVector<unsigned> getThreadOrder() const;

SmallVector<unsigned> getSizePerThread() const;
SmallVector<unsigned> getShapePerCTATile(ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>()) const;

std::optional<LinearLayout> toLinearLayout(ArrayRef<int64_t> shape) const;
}];
@@ -765,13 +759,6 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
"bool",
"supportReduction">,

InterfaceMethod<"Return shape per CTA.",
"SmallVector<unsigned>",
"getShapePerCTATileForOperand",
(ins "ArrayRef<int64_t>":$tensorShape,
"int":$kWidth,
"int":$opIdx)>,

InterfaceMethod<"Return size per thread for dot operands.",
"SmallVector<unsigned>",
"getSizePerThreadForOperand",
@@ -905,7 +892,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
return true;
}
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
@@ -1017,7 +1003,6 @@ Row | warp 0 warp 2
return true;
}
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
SmallVector<int64_t> getElemsPerInstrForOperands() const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
@@ -1153,7 +1138,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
return false;
};
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;

SmallVector<unsigned> getContigPerThread() {
assert(isAmpere() || isHopper());
6 changes: 2 additions & 4 deletions lib/Analysis/Allocation.cpp
@@ -41,10 +41,8 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,

auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
auto srcShapePerCTATile =
gpu::getShapePerCTATile(srcLayout, srcTy.getShape());
auto dstShapePerCTATile =
gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
auto srcShapePerCTATile = gpu::getShapePerCTATile(srcLayout);
auto dstShapePerCTATile = gpu::getShapePerCTATile(dstLayout);

assert(srcTy.getRank() == dstTy.getRank() &&
"src and dst must have the same rank");
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -174,8 +174,8 @@ struct ConvertLayoutOpConversion
SmallVector<unsigned> outNumCTAsEachRep(rank);
SmallVector<unsigned> inNumCTAs(rank);
SmallVector<unsigned> outNumCTAs(rank);
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape);
auto srcShapePerCTATile = getShapePerCTATile(srcLayout);
auto dstShapePerCTATile = getShapePerCTATile(dstLayout);
auto shapePerCTA = getShapePerCTA(srcLayout, shape);

for (unsigned d = 0; d < rank; ++d) {
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -421,7 +421,7 @@ struct ReduceOpConversion
auto resultIndices = emitIndices(loc, rewriter, targetInfo,
resultLayout, resultTy, true);
auto resultShape = resultTy.getShape();
auto resultCTATile = getShapePerCTATile(resultLayout, resultShape);
auto resultCTATile = getShapePerCTATile(resultLayout);
assert(resultIndices.size() == resultElems);

SmallVector<Value> resultVals(resultElems);
141 changes: 22 additions & 119 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -204,12 +204,28 @@ SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
}
return ret;
}

SmallVector<unsigned> getShapePerCTATile(Attribute layout,
ArrayRef<int64_t> tensorShape) {
SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
if (auto distributedLayout =
mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
return distributedLayout.getShapePerCTATile(tensorShape);
auto sizePerThread = distributedLayout.getSizePerThread();
auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
// threadsPerWarp and warpsPerCTA do not align with this function for
// slice layouts; recompute them from the parent layout instead.
if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
warpsPerCTA = getWarpsPerCTA(sliceLayout.getParent());
Review comment (Contributor Author): This is a real bug in the upstream, but the NV and AMD code doesn't use this legacy layout function. We need to align our lowering code with the upstream once we enable the LinearLayout (LL) for DPAS-related layouts.

Reply (Contributor): Please create an issue to track this.

warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
}
assert(sizePerThread.size() == threadsPerWarp.size() &&
sizePerThread.size() == warpsPerCTA.size());
SmallVector<unsigned> shape;
for (auto [size, thread, warp] :
llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) {
shape.push_back(size * thread * warp);
}
return shape;
} else {
llvm::report_fatal_error("getShapePerCTATile not implemented");
return SmallVector<unsigned>();
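To make the new generic path concrete, here is a standalone sketch of the same arithmetic with hypothetical layout vectors; the real code reads them from the encoding attribute, and only the per-dimension multiplication and the slice-dimension erase are mirrored here:

```cpp
// Sketch of getShapePerCTATile's arithmetic; all layout values below are
// illustrative stand-ins, not taken from a real encoding.
#include <cassert>
#include <cstdio>
#include <vector>

using Vec = std::vector<unsigned>;

Vec shapePerCTATile(Vec sizePerThread, Vec threadsPerWarp, Vec warpsPerCTA) {
  assert(sizePerThread.size() == threadsPerWarp.size() &&
         sizePerThread.size() == warpsPerCTA.size());
  Vec shape;
  for (size_t d = 0; d < sizePerThread.size(); ++d)
    shape.push_back(sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d]);
  return shape;
}

int main() {
  // A blocked-like layout: 1x4 elements per thread, an 8x4 thread grid per
  // warp, 2x2 warps per CTA -> a 16x32 tile.
  Vec t1 = shapePerCTATile({1, 4}, {8, 4}, {2, 2});
  printf("blocked: [%u, %u]\n", t1[0], t1[1]); // blocked: [16, 32]

  // A slice-like layout along dim 0: erase the sliced dimension from the
  // parent's vectors before multiplying, mirroring the branch above.
  Vec sizePerThread{1, 4}, threadsPerWarp{8, 4}, warpsPerCTA{2, 2};
  unsigned dim = 0;
  sizePerThread.erase(sizePerThread.begin() + dim);
  threadsPerWarp.erase(threadsPerWarp.begin() + dim);
  warpsPerCTA.erase(warpsPerCTA.begin() + dim);
  Vec t2 = shapePerCTATile(sizePerThread, threadsPerWarp, warpsPerCTA);
  printf("slice:   [%u]\n", t2[0]); // slice: [32]
  return 0;
}
```

For an Ampere-like mma layout (2x2 elements per thread on an 8x4 thread grid), the same formula yields 16x8 per warp, matching the values that the deleted NvidiaMmaEncodingAttr::getShapePerCTATile override hard-coded.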
@@ -691,14 +707,6 @@ SmallVector<unsigned> BlockedEncodingAttr::getThreadOrder() const {
SmallVector<unsigned> BlockedEncodingAttr::getSizePerThread() const {
return SmallVector<unsigned>(getSizePerThread__());
}
SmallVector<unsigned>
BlockedEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
SmallVector<unsigned> shape;
for (unsigned d = 0, n = getOrder().size(); d < n; ++d)
shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] *
getWarpsPerCTA()[d]);
return shape;
}

template <class T>
SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
@@ -800,12 +808,6 @@ SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
sizePerThread.erase(sizePerThread.begin() + getDim());
return sizePerThread;
}
SmallVector<unsigned>
SliceEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
SmallVector<unsigned> shape = ::getShapePerCTATile(getParent(), tensorShape);
shape.erase(shape.begin() + getDim());
return shape;
}

//

@@ -999,9 +1001,9 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
}
if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
auto shapePerCTA = getShapePerCTA(*this, shape);
auto shapePerCTATile = ::getShapePerCTATile(blockedLayout);
auto shapePerCTATile = getShapePerCTATile(blockedLayout);
auto order = blockedLayout.getOrder();
auto sizePerThread = ::getSizePerThread(blockedLayout);
auto sizePerThread = blockedLayout.getSizePerThread();

int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0];
int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0];
@@ -1072,19 +1074,6 @@ SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
/*kMajor*/ true);
}
}
SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
ArrayRef<int64_t> tensorShape) const {
auto parentLayout = getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
return parentMmaLayout.getShapePerCTATileForOperand(
tensorShape, getKWidth(), getOpIdx());
} else {
llvm::report_fatal_error(
"DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
"supported yet");
}
}

LogicalResult DotOperandEncodingAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -1606,16 +1595,6 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
//===----------------------------------------------------------------------===//
// TODO: there is a lot of common code with MmaEncoding here

SmallVector<unsigned>
AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
auto warpsPerCTA = getWarpsPerCTA();
auto rank = warpsPerCTA.size();
SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
shapePerCTATile[rank - 1] *= getMDim();
shapePerCTATile[rank - 2] *= getNDim();
return shapePerCTATile;
}

SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAsPerCGA() const {
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
}
@@ -1759,43 +1738,10 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
return sizePerThread;
}

SmallVector<unsigned>
AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
int kWidth, int opIdx) const {
assert(getMDim() == 32 || getMDim() == 16);
auto parentShapePerCTATile = getShapePerCTATile(shape);
auto rank = parentShapePerCTATile.size();
if (opIdx == 0) {
if (rank == 2)
return {parentShapePerCTATile[rank - 2], 32};
else
return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32};
} else if (opIdx == 1) {
if (rank == 2)
return {32, parentShapePerCTATile[rank - 1]};
else
return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]};
} else {
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
}
llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1");
}

//===----------------------------------------------------------------------===//
// Wmma encoding
//===----------------------------------------------------------------------===//

SmallVector<unsigned>
AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
auto warpsPerCTA = getWarpsPerCTA();
auto rank = warpsPerCTA.size();
SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());

auto mnkDim = getMNKDimPerInstr();
shapePerCTATile[rank - 2] *= mnkDim[0];
shapePerCTATile[rank - 1] *= mnkDim[1];
return shapePerCTATile;
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
auto rank = getWarpsPerCTA().size();
return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -1860,21 +1806,6 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
return sizePerThread;
}

SmallVector<unsigned>
AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
int kWidth, int opIdx) const {
auto parentShapePerCTA = getShapePerCTATile(shape);
auto rank = shape.size();
assert(rank == 2);
if (opIdx == 0) {
return {parentShapePerCTA[0], static_cast<unsigned>(shape[1])};
} else if (opIdx == 1) {
return {static_cast<unsigned>(shape[0]), parentShapePerCTA[1]};
} else {
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
}
}

unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand(
ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx);
@@ -1993,24 +1924,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
llvm_unreachable("Unexpected mma version");
}

SmallVector<unsigned>
NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
if (isAmpere()) {
auto warpsPerCTA = getWarpsPerCTA();
auto rank = warpsPerCTA.size();
SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(),
warpsPerCTA.end());
shapePerCTATile[rank - 1] *= 8;
shapePerCTATile[rank - 2] *= 16;
return shapePerCTATile;
}
if (isHopper()) {
auto instrShape = getInstrShape();
return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]};
}
llvm::report_fatal_error("Unexpected MMA layout version found");
}

SmallVector<unsigned>
NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
auto rank = getWarpsPerCTA().size();
@@ -2051,16 +1964,6 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
}
}

SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
auto shapePerCTATile = getShapePerCTATile(shape);
auto rank = shapePerCTATile.size();
auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
// 4 threads * 2 subtiles
shapePerCTATile[kDim] = kWidth * 2 * 4;
return shapePerCTATile;
}
SmallVector<unsigned>
NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
auto rank = getWarpsPerCTA().size();
3 changes: 1 addition & 2 deletions third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
@@ -78,8 +78,7 @@ LogicalResult ExtractSliceOp::verify() {
}

auto srcShape = srcTy.getShape();
auto shapePerCTATile =
mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);
auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcLayout);
shapePerCTATile[0] =
std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
shapePerCTATile[1] =
@@ -70,7 +70,7 @@ struct ExtractSliceOpConversion
auto order = triton::gpu::getOrder(srcLayout);

// Calculate valid total number of workers in each dimension
auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout, srcShape);
auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout);
shapePerCTATile[0] =
std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
shapePerCTATile[1] =
@@ -48,7 +48,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
} else {
warpOrder = triton::gpu::getWarpOrder(layout);
}
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
@@ -84,7 +84,6 @@ along the row (resp. col) dimension.
SmallVector<int64_t> getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const;
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth,unsigned opIdx) const;
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const;
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
