
Commit 593a1b5

[LAYOUTS] Kill getWarpsPerCTA(Attribute) and prefer LinearLayout-based impl (#6252)
We remove the manual implementations in favour of the generic LinearLayout (LL) implementation.
1 parent e196446 commit 593a1b5
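
For readers skimming the diffs below, here is a declaration-level sketch of the API change, with stand-in std:: types replacing MLIR's Attribute, RankedTensorType, and llvm::SmallVector (the real implementation delegates to toLinearEncoding, as the Dialect.cpp hunk shows):

#include <cstdint>
#include <vector>

// Stand-ins for mlir::Attribute / mlir::RankedTensorType, illustration only.
struct Attribute {};
struct RankedTensorType {
  Attribute encoding;
  std::vector<int64_t> shape;
  Attribute getEncoding() const { return encoding; }
  const std::vector<int64_t> &getShape() const { return shape; }
};

// Gone: getWarpsPerCTA(Attribute) and getWarpsPerCTAWithUniqueData(...).
// Kept: a single shape-aware query (backed by the linear-layout machinery)...
std::vector<unsigned> getWarpsPerCTA(Attribute layout,
                                     const std::vector<int64_t> &shape);
// ...plus a convenience overload on the tensor type.
inline std::vector<unsigned> getWarpsPerCTA(RankedTensorType type) {
  return getWarpsPerCTA(type.getEncoding(), type.getShape());
}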

File tree

11 files changed: +36 -89 lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 10 additions & 11 deletions
@@ -99,9 +99,16 @@ unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);

 SmallVector<unsigned> getElemsPerThread(Type type);

-// Returns the number of warps per CTA that may have access to replicated
-// elements. If you want non-replicated warps, use getWarpsPerCTAWithUniqueData.
-SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
+// Returns the number of warps per CTA that have access to non-replicated
+// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
+// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2],
+// returns [1, 1], since the first warp has access to the full tensor, whereas
+// the other warps have access to replicated elements.
+SmallVector<unsigned> getWarpsPerCTA(Attribute layout,
+                                     ArrayRef<int64_t> tensorShape);
+inline SmallVector<unsigned> getWarpsPerCTA(RankedTensorType type) {
+  return getWarpsPerCTA(type.getEncoding(), type.getShape());
+}

 // Returns the number of contiguous elements of the logical tensor that each
 // thread has access to, on each dimension of the tensor. For a blocked layout

@@ -122,14 +129,6 @@ inline SmallVector<unsigned> getThreadsPerWarp(RankedTensorType type) {
   return getThreadsPerWarp(type.getEncoding(), type.getShape());
 }

-// Returns the number of warps per CTA that have access to non-replicated
-// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
-// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2],
-// returns [1, 1], since the first warp has access to the full tensor, whereas
-// the other warps have access to replicated elements.
-SmallVector<unsigned>
-getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
-
 // Returns the dimensions of the tensor from minor (fast-varying) to
 // major (slow-varying). For distributed layouts, this represents
 // the order of the elements within a thread.
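
The worked example in the new comment can be checked with a small standalone model. This is an assumption about how the query behaves for blocked layouts (not the actual LinearLayout code): one warp covers sizePerThread * threadsPerWarp elements per dimension, and any warps beyond what the shape requires only ever see replicas.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Standalone model of the documented example, not the real implementation.
std::vector<unsigned>
blockedWarpsPerCTA(const std::vector<unsigned> &sizePerThread,
                   const std::vector<unsigned> &threadsPerWarp,
                   const std::vector<unsigned> &warpsPerCTA,
                   const std::vector<int64_t> &shape) {
  std::vector<unsigned> result(shape.size());
  for (size_t d = 0; d < shape.size(); ++d) {
    // Elements one warp covers along dimension d.
    uint64_t perWarp = uint64_t(sizePerThread[d]) * threadsPerWarp[d];
    // Warps actually needed to cover the tensor; extra warps along this
    // dimension hold only replicated elements.
    uint64_t needed = (uint64_t(shape[d]) + perWarp - 1) / perWarp;
    result[d] = unsigned(std::min<uint64_t>(warpsPerCTA[d], needed));
  }
  return result;
}

int main() {
  // sizePerThread = [1,1], threadsPerWarp = [2,16], warpsPerCTA = [1,4],
  // shape = [2,2]  ->  prints [1, 1], matching the comment above.
  auto w = blockedWarpsPerCTA({1, 1}, {2, 16}, {1, 4}, {2, 2});
  std::printf("[%u, %u]\n", w[0], w[1]);
}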

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 5 additions & 10 deletions
@@ -597,11 +597,6 @@ We call each individual tile "rep".
     /*defaultImplementation=*/[{
       return toLinearEncoding($_self, shape).getElemsPerThread(shape);
     }]>,
-    // Interface for the meta information about the multiple thread hierarchy.
-    InterfaceMethod<"Get the shape of the warps per CTA.",
-                    "SmallVector<unsigned>",
-                    "getWarpsPerCTA">,
-
     InterfaceMethod<"Convert to LinearLayout.",
                     "LinearLayout",
                     "toLinearLayout",

@@ -653,7 +648,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
     SmallVector<unsigned> getCTAsPerCGA() const;
     SmallVector<unsigned> getCTAOrder() const;
     SmallVector<unsigned> getCTASplitNum() const;
-    SmallVector<unsigned> getWarpsPerCTA() const;

     LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
   }];

@@ -703,6 +697,7 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
     SmallVector<unsigned> basesPerDim(StringAttr dimName,
                                       bool skipBroadcast = true) const;
     SmallVector<unsigned> getThreadsPerWarp() const;
+    SmallVector<unsigned> getWarpsPerCTA() const;

     // [FIXME LL] Supports legacy behaviour. We should remove these functions
     SmallVector<unsigned> getShapePerCTATile() const;

@@ -813,7 +808,7 @@ for
     ins
     ArrayRefParameter<"unsigned">:$sizePerThread,
     ArrayRefParameter<"unsigned">:$threadsPerWarp,
-    ArrayRefParameter<"unsigned">:$warpsPerCTA__,
+    ArrayRefParameter<"unsigned">:$warpsPerCTA,
     ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first

     // CTALayout is optional in the textual IR. If omitted, we infer it to be a

@@ -1012,7 +1007,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
     ins
     "unsigned": $versionMajor,
     "unsigned": $versionMinor,
-    ArrayRefParameter<"unsigned">:$warpsPerCTA__,
+    ArrayRefParameter<"unsigned">:$warpsPerCTA,
     "unsigned":$MDim,
     "unsigned":$NDim,
     "bool":$isTransposed,

@@ -1132,7 +1127,7 @@ Row |
     ins
     "unsigned": $version,
     "bool":$isTransposed,
-    ArrayRefParameter<"unsigned">:$warpsPerCTA__,
+    ArrayRefParameter<"unsigned">:$warpsPerCTA,
     "CTALayoutAttr":$CTALayout
   );

@@ -1237,7 +1232,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     ins
     "unsigned":$versionMajor,
     "unsigned":$versionMinor,
-    ArrayRefParameter<"unsigned">:$warpsPerCTA__,
+    ArrayRefParameter<"unsigned">:$warpsPerCTA,
     "CTALayoutAttr":$CTALayout,
     ArrayRefParameter<"unsigned">:$instrShape
   );
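
A side effect of renaming the storage parameter from warpsPerCTA__ to warpsPerCTA: the ODS-generated accessor getWarpsPerCTA() now hands out the parameter directly as an ArrayRef<unsigned> view instead of going through the removed hand-written SmallVector wrappers, which is why several call sites in the diffs below gain an explicit to_vector(...) copy before mutating. A small standalone analogy, with std::span and std::vector standing in for ArrayRef and SmallVector:

#include <cstdio>
#include <span>
#include <vector>

// Stand-in for an ODS-generated attribute: the accessor returns a view,
// not an owning container.
struct BlockedEncoding {
  std::vector<unsigned> storage{1, 4};
  std::span<const unsigned> getWarpsPerCTA() const { return storage; }
};

int main() {
  BlockedEncoding enc;
  // Read-only uses can consume the view directly...
  for (unsigned w : enc.getWarpsPerCTA())
    std::printf("%u ", w);
  std::printf("\n");
  // ...but mutation needs an explicit owned copy first, which is what the
  // new to_vector(...) calls provide.
  auto view = enc.getWarpsPerCTA();
  std::vector<unsigned> warps(view.begin(), view.end());
  warps.insert(warps.begin(), 1); // e.g. expand_dims adds a unit dimension
  std::printf("rank after insert: %zu\n", warps.size());
}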

lib/Analysis/Utility.cpp

Lines changed: 5 additions & 5 deletions
@@ -99,15 +99,15 @@ bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) {
 }

 unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
-  return getWarpsPerCTAWithUniqueData(srcEncoding, srcShape)[axis];
+  return getWarpsPerCTA(srcEncoding, srcShape)[axis];
 }

 unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
   return getThreadsPerWarp(srcEncoding, srcShape)[axis];
 }

 bool ReduceOpHelper::isWarpSynchronous() {
-  return getWarpsPerCTAWithUniqueData(srcEncoding, srcShape)[axis] == 1;
+  return getWarpsPerCTA(srcEncoding, srcShape)[axis] == 1;
 }

 SmallVector<unsigned> ReduceOpHelper::getScratchRepShape() {

@@ -175,7 +175,7 @@ unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() {
 unsigned ScanLoweringHelper::getAxisNumBlocks() {
   auto contigPerThread = getEncoding().getContigPerThread();
   auto threadsPerWarp = getEncoding().getThreadsPerWarp();
-  auto warpsPerCTA = getWarpsPerCTA(getEncoding());
+  auto warpsPerCTA = getEncoding().getWarpsPerCTA();
   unsigned axis = getAxis();
   return ceil<unsigned>(
       getShape()[axis],

@@ -185,7 +185,7 @@ unsigned ScanLoweringHelper::getAxisNumBlocks() {
 unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
   auto contigPerThread = getEncoding().getContigPerThread();
   auto threadsPerWarp = getEncoding().getThreadsPerWarp();
-  auto warpsPerCTA = getWarpsPerCTA(getEncoding());
+  auto warpsPerCTA = getEncoding().getWarpsPerCTA();
   auto rank = contigPerThread.size();
   unsigned axis = getAxis();
   unsigned numBlocks = 1;

@@ -522,7 +522,7 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
   unsigned stride = 1;
   auto contigPerThread = getEncoding().getContigPerThread();
   auto threadsPerWarp = getEncoding().getThreadsPerWarp();
-  auto warpsPerCTA = getWarpsPerCTA(getEncoding());
+  auto warpsPerCTA = getEncoding().getWarpsPerCTA();
   for (unsigned dim : order) {
     if (dim == getAxis())
       return stride;
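
As context for the ScanLoweringHelper hunks, warpsPerCTA feeds straight into block-count arithmetic: the visible `return ceil<unsigned>(getShape()[axis], ...)` divides the shape by the elements one CTA covers along the axis, presumably contigPerThread * threadsPerWarp * warpsPerCTA given the variables gathered just above (the divisor is cut off at the hunk boundary). With illustrative numbers, not taken from the diff:

#include <cstdio>

int main() {
  // Along the scanned axis: one CTA covers
  // contigPerThread * threadsPerWarp * warpsPerCTA elements.
  unsigned contigPerThread = 1, threadsPerWarp = 16, warpsPerCTA = 4;
  unsigned shape = 128;
  unsigned perCTA = contigPerThread * threadsPerWarp * warpsPerCTA; // 64
  unsigned numBlocks = (shape + perCTA - 1) / perCTA; // ceil(128 / 64) == 2
  std::printf("numBlocks = %u\n", numBlocks);
}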

lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
@@ -87,9 +87,8 @@ static SmallVector<Value> computeCrossWarpHistogram(
     Value threadId, int numWarps) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   SmallVector<Value> histogramValues;
-  unsigned numWarpsWithUniqueData =
-      mlir::triton::gpu::getWarpsPerCTAWithUniqueData(srcType.getEncoding(),
-                                                      srcType.getShape())[0];
+  unsigned numWarpsWithUniqueData = mlir::triton::gpu::getWarpsPerCTA(
+      srcType.getEncoding(), srcType.getShape())[0];
   Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1));
   // Initialize the shared memory with zeros.
   int64_t numElementPerThread =

lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -393,7 +393,7 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
   auto srcEncoding = helper.getEncoding();

   auto threadsPerWarp = srcEncoding.getThreadsPerWarp();
-  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
+  auto warpsPerCTA = srcEncoding.getWarpsPerCTA();
   auto [multiDimLaneId, isRepresentativeLane] =
       getMultiDimLaneId(rewriter, helper, laneId);
   auto [multiDimWarpId, isRepresentativeWarp] =

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ struct TritonExpandDimsPattern
     retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1);
     auto retThreadsPerWarp = to_vector(argEncoding.getThreadsPerWarp());
     retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.getAxis(), 1);
-    auto retWarpsPerCTA = argEncoding.getWarpsPerCTA();
+    auto retWarpsPerCTA = to_vector(argEncoding.getWarpsPerCTA());
     retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1);
     SmallVector<unsigned, 4> retOrder(retShape.size());
     std::iota(retOrder.begin(), retOrder.end(), 0);

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 6 additions & 45 deletions
@@ -85,18 +85,8 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout,
   return toLinearEncoding(layout, shape).getThreadsPerWarp();
 }

-SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
-  if (auto distributedLayout =
-          mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
-    return distributedLayout.getWarpsPerCTA();
-  }
-
-  llvm::report_fatal_error("getWarpsPerCTA not implemented");
-  return SmallVector<unsigned>();
-}
-
-SmallVector<unsigned> getWarpsPerCTAWithUniqueData(Attribute layout,
-                                                   ArrayRef<int64_t> shape) {
+SmallVector<unsigned> getWarpsPerCTA(Attribute layout,
+                                     ArrayRef<int64_t> shape) {
   return toLinearEncoding(layout, shape).getWarpsPerCTA();
 }

@@ -578,9 +568,6 @@ SmallVector<unsigned> BlockedEncodingAttr::getCTAOrder() const {
 SmallVector<unsigned> BlockedEncodingAttr::getCTASplitNum() const {
   return SmallVector<unsigned>(getCTALayout().getCTASplitNum());
 }
-SmallVector<unsigned> BlockedEncodingAttr::getWarpsPerCTA() const {
-  return SmallVector<unsigned>(getWarpsPerCTA__());
-}

 template <class T>
 SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {

@@ -637,15 +624,6 @@ SmallVector<unsigned> SliceEncodingAttr::getCTAsPerCGA() const {
   llvm::report_fatal_error(
       "getCTAsPerCGA for SliceEncodingAttr is not well-defined");
 }
-SmallVector<unsigned> SliceEncodingAttr::getWarpsPerCTA() const {
-  auto parent = getParent();
-  auto parentWarpsPerCTA = ::getWarpsPerCTA(parent);
-  SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
-  warpsPerCTA.erase(warpsPerCTA.begin() + getDim());
-  int32_t nextDim = getDim() < warpsPerCTA.size() ? getDim() : getDim() - 1;
-  warpsPerCTA[nextDim] *= parentWarpsPerCTA[getDim()];
-  return warpsPerCTA;
-}

 // Wmma encoding

@@ -701,14 +679,6 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
   res[kDim] = 1;
   return res;
 }
-SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
-  auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent());
-  auto warps = distributedLayout.getWarpsPerCTA();
-  auto rank = warps.size();
-  auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
-  warps[kDim] = 1;
-  return warps;
-}

 LogicalResult DotOperandEncodingAttr::verify(
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,

@@ -1306,7 +1276,7 @@ void NvidiaMmaEncodingAttr::print(AsmPrinter &printer) const {
           << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]";

   maybePrintCTALayout(getContext(), printer, getCTALayout(),
-                      /*rank=*/getWarpsPerCTA().size());
+                      /*rank=*/getRank());

   printer << ", instrShape = [" << getInstrShape() << "]}>";
 }

@@ -1386,11 +1356,11 @@ void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
   printer << "<{"
           << "versionMajor = " << getVersionMajor()                      //
           << ", versionMinor = " << getVersionMinor()                    //
-          << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]"    //
+          << ", warpsPerCTA = [" << getWarpsPerCTA() << "]"              //
          << ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]" //
           << ", isTransposed = " << getIsTransposed();
   maybePrintCTALayout(getContext(), printer, getCTALayout(),
-                      /*rank=*/getWarpsPerCTA().size());
+                      /*rank=*/getRank());
   printer << "}>";
 }

@@ -1721,9 +1691,6 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAOrder() const {
 SmallVector<unsigned> AMDMfmaEncodingAttr::getCTASplitNum() const {
   return SmallVector<unsigned>(getCTALayout().getCTASplitNum());
 }
-SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpsPerCTA() const {
-  return SmallVector<unsigned>(getWarpsPerCTA__());
-}

 SmallVector<int64_t>
 AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {

@@ -1842,9 +1809,6 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAOrder() const {
 SmallVector<unsigned> AMDWmmaEncodingAttr::getCTASplitNum() const {
   return SmallVector<unsigned>(getCTALayout().getCTASplitNum());
 }
-SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpsPerCTA() const {
-  return SmallVector<unsigned>(getWarpsPerCTA__());
-}

 SmallVector<int64_t> AMDWmmaEncodingAttr::getElemsPerInstrForOperands() const {
   return {16, 16};

@@ -1916,9 +1880,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getCTAOrder() const {
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getCTASplitNum() const {
   return SmallVector<unsigned>(getCTALayout().getCTASplitNum());
 }
-SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpsPerCTA() const {
-  return SmallVector<unsigned>(getWarpsPerCTA__());
-}

 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {

@@ -1933,7 +1894,7 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
          "kWidth must be >= 32 / bitwidth for this function to be well-defined");
   auto rank = shape.size();
   // Broadcast long K
-  auto warpsPerCTA = getWarpsPerCTA();
+  auto warpsPerCTA = to_vector(getWarpsPerCTA());
   auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
   warpsPerCTA[kDim] = 1;

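
The deleted SliceEncodingAttr::getWarpsPerCTA above is the most involved of the removed specializations: it dropped the sliced dimension and folded its warp count into the neighbouring one. A standalone model of that now-removed logic follows; note that the generic shape-aware replacement need not reproduce it exactly, since it also discounts warps that only see replicated data.

#include <cstdio>
#include <vector>

// Model of the deleted SliceEncodingAttr::getWarpsPerCTA: erase the sliced
// dimension and multiply its warp count into the adjacent dimension.
std::vector<unsigned> sliceWarpsPerCTA(std::vector<unsigned> parent,
                                       unsigned dim) {
  unsigned folded = parent[dim];
  parent.erase(parent.begin() + dim);
  unsigned next = dim < parent.size() ? dim : dim - 1; // clamp at the end
  parent[next] *= folded;
  return parent;
}

int main() {
  // Slicing dim 0 out of a parent with warpsPerCTA = [2, 4] gave [8].
  auto warps = sliceWarpsPerCTA({2, 4}, 0);
  std::printf("[%u]\n", warps[0]);
}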

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 2 additions & 2 deletions
@@ -364,7 +364,7 @@ static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
 LinearLayout
 AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   int rank = shape.size();
-  assert(rank == getWarpsPerCTA().size());
+  assert(rank == getRank());

   bool hasBatchDim = rank == 3;
   int mIndex = 0 + hasBatchDim;

@@ -712,7 +712,7 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
 LinearLayout
 AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   int rank = shape.size();
-  assert(rank == getWarpsPerCTA().size());
+  assert(rank == getRank());

   bool hasBatchDim = rank == 3;
   int mIndex = 0 + hasBatchDim;

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
     }
     if (auto mmaEncoding =
             dyn_cast<NvidiaMmaEncodingAttr>(resTy.getEncoding())) {
-      return getWarpsPerCTA(mmaEncoding);
+      return to_vector(mmaEncoding.getWarpsPerCTA());
     }
     hasChainedDot = true;
   }

third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp

Lines changed: 3 additions & 4 deletions
@@ -72,11 +72,10 @@ createTmpLayout(triton::gpu::DistributedEncodingTrait layout,
                                                 src.getKWidth());
   }
   if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {
-    // TODO: think of a way to construct slice layouts based on warpsPerCTA
-    // argument
-    auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent());
+    auto warps = to_vector(warpsPerCTA);
+    warps.insert(warps.begin() + src.getDim(), 1);
     return triton::gpu::SliceEncodingAttr::get(
-        ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA));
+        ctx, src.getDim(), createTmpLayout(src.getParent(), warps));
   }
   // TODO: support linear layout if needed.
   if (isa<triton::gpu::LinearEncodingAttr>(layout))
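
This hunk resolves the old TODO by running the construction in the other direction: instead of querying the parent's warpsPerCTA, it derives a parent warpsPerCTA from the requested one by re-inserting a unit warp dimension at the sliced axis. A minimal model (names are stand-ins):

#include <cassert>
#include <vector>

// Model of the new slice handling in createTmpLayout: given the warpsPerCTA
// requested for the (rank-reduced) slice, re-insert a unit warp dimension at
// the sliced axis to obtain a valid warpsPerCTA for the parent layout.
std::vector<unsigned> parentWarpsPerCTA(std::vector<unsigned> sliceWarps,
                                        unsigned dim) {
  sliceWarps.insert(sliceWarps.begin() + dim, 1);
  return sliceWarps;
}

int main() {
  // A 2-D slice of a 3-D parent, sliced at dim 1: [4, 2] -> [4, 1, 2].
  auto warps = parentWarpsPerCTA({4, 2}, 1);
  assert(warps.size() == 3 && warps[1] == 1);
}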
