Skip to content

Commit 8f30fe0

Browse files
Revert "[LAYOUTS] Kill getWarpsPerCTA(Attribute) and prefer LinearLayout-based impl (#6252)"
This reverts commit 593a1b5.
1 parent 1e98c47 commit 8f30fe0

File tree

11 files changed

+264
-43
lines changed

11 files changed

+264
-43
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,14 @@ unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
9999

100100
SmallVector<unsigned> getElemsPerThread(Type type);
101101

102-
// Returns the number of warps per CTA that have access to non-replicated
103-
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
104-
// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2],
105-
// returns [1, 1], since the first warp has access to the full tensor, whereas
106-
// the other warps have access to replicated elements.
107-
SmallVector<unsigned> getWarpsPerCTA(Attribute layout,
108-
ArrayRef<int64_t> tensorShape);
109-
inline SmallVector<unsigned> getWarpsPerCTA(RankedTensorType type) {
110-
return getWarpsPerCTA(type.getEncoding(), type.getShape());
111-
}
102+
// Returns the number of threads per warp that may have access to replicated
103+
// elements. If you want non-replicated threads, use
104+
// getThreadsPerWarpWithUniqueData.
105+
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
106+
107+
// Returns the number of warps per CTA that may have access to replicated
108+
// elements. If you want non-replicated warps, use getWarpsPerCTAWithUniqueData.
109+
SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
112110

113111
// Returns the number of contiguous elements of the logical tensor that each
114112
// thread has access to, on each dimension of the tensor. For a blocked layout
@@ -127,6 +125,14 @@ SmallVector<unsigned>
127125
getThreadsPerWarpWithUniqueData(Attribute layout,
128126
ArrayRef<int64_t> tensorShape);
129127

128+
// Returns the number of warps per CTA that have access to non-replicated
129+
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
130+
// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2],
131+
// returns [1, 1], since the first warp has access to the full tensor, whereas
132+
// the other warps have access to replicated elements.
133+
SmallVector<unsigned>
134+
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
135+
130136
// Returns the dimensions of the tensor from minor (fast-varying) to
131137
// major (slow-varying). For distributed layouts, this represents
132138
// the order of the elements within a thread.

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,15 @@ We call each individual tile "rep".
597597
/*defaultImplementation=*/[{
598598
return toLinearEncoding($_self, shape).getElemsPerThread(shape);
599599
}]>,
600+
// Interface for the meta information about the multiple thread hierarchy.
601+
InterfaceMethod<"Get the shape of the warps per CTA.",
602+
"SmallVector<unsigned>",
603+
"getWarpsPerCTA">,
604+
605+
606+
InterfaceMethod<"Get the shape of the threads per warp",
607+
"SmallVector<unsigned>",
608+
"getThreadsPerWarp">,
600609
InterfaceMethod<"Convert to LinearLayout.",
601610
"LinearLayout",
602611
"toLinearLayout",
@@ -662,6 +671,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
662671
SmallVector<unsigned> getCTAsPerCGA() const;
663672
SmallVector<unsigned> getCTAOrder() const;
664673
SmallVector<unsigned> getCTASplitNum() const;
674+
SmallVector<unsigned> getWarpsPerCTA() const;
675+
SmallVector<unsigned> getThreadsPerWarp() const;
665676

666677
LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
667678

@@ -714,8 +725,6 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
714725
// If skipBroadcast is false, we count a base zero
715726
SmallVector<unsigned> basesPerDim(StringAttr dimName,
716727
bool skipBroadcast = true) const;
717-
SmallVector<unsigned> getThreadsPerWarp() const;
718-
SmallVector<unsigned> getWarpsPerCTA() const;
719728

720729
// [FIXME LL] Supports legacy behaviour. We should remove these functions
721730
SmallVector<unsigned> getShapePerCTATile() const;
@@ -825,8 +834,8 @@ for
825834
let parameters = (
826835
ins
827836
ArrayRefParameter<"unsigned">:$sizePerThread,
828-
ArrayRefParameter<"unsigned">:$threadsPerWarp,
829-
ArrayRefParameter<"unsigned">:$warpsPerCTA,
837+
ArrayRefParameter<"unsigned">:$threadsPerWarp__,
838+
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
830839
ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first
831840

832841
// CTALayout is optional in the textual IR. If omitted, we infer it to be a
@@ -1030,7 +1039,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
10301039
ins
10311040
"unsigned": $versionMajor,
10321041
"unsigned": $versionMinor,
1033-
ArrayRefParameter<"unsigned">:$warpsPerCTA,
1042+
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
10341043
"unsigned":$MDim,
10351044
"unsigned":$NDim,
10361045
"bool":$isTransposed,
@@ -1151,7 +1160,7 @@ Row |
11511160
ins
11521161
"unsigned": $version,
11531162
"bool":$isTransposed,
1154-
ArrayRefParameter<"unsigned">:$warpsPerCTA,
1163+
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
11551164
"CTALayoutAttr":$CTALayout
11561165
);
11571166

@@ -1257,7 +1266,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12571266
ins
12581267
"unsigned":$versionMajor,
12591268
"unsigned":$versionMinor,
1260-
ArrayRefParameter<"unsigned">:$warpsPerCTA,
1269+
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
12611270
"CTALayoutAttr":$CTALayout,
12621271
ArrayRefParameter<"unsigned">:$instrShape
12631272
);

lib/Analysis/Utility.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) {
100100
}
101101

102102
unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
103-
return getWarpsPerCTA(srcEncoding, srcShape)[axis];
103+
return getWarpsPerCTAWithUniqueData(srcEncoding, srcShape)[axis];
104104
}
105105

106106
unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
@@ -113,7 +113,7 @@ bool ReduceOpHelper::isWarpSynchronous() {
113113
// in order to remove this change.
114114
if (!srcEncoding)
115115
return true;
116-
return getWarpsPerCTA(srcEncoding, srcShape)[axis] == 1;
116+
return getWarpsPerCTAWithUniqueData(srcEncoding, srcShape)[axis] == 1;
117117
}
118118

119119
SmallVector<unsigned> ReduceOpHelper::getScratchRepShape() {
@@ -180,8 +180,8 @@ unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() {
180180

181181
unsigned ScanLoweringHelper::getAxisNumBlocks() {
182182
auto contigPerThread = getEncoding().getContigPerThread();
183-
auto threadsPerWarp = getEncoding().getThreadsPerWarp();
184-
auto warpsPerCTA = getEncoding().getWarpsPerCTA();
183+
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
184+
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
185185
unsigned axis = getAxis();
186186
return ceil<unsigned>(
187187
getShape()[axis],
@@ -190,8 +190,8 @@ unsigned ScanLoweringHelper::getAxisNumBlocks() {
190190

191191
unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
192192
auto contigPerThread = getEncoding().getContigPerThread();
193-
auto threadsPerWarp = getEncoding().getThreadsPerWarp();
194-
auto warpsPerCTA = getEncoding().getWarpsPerCTA();
193+
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
194+
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
195195
auto rank = contigPerThread.size();
196196
unsigned axis = getAxis();
197197
unsigned numBlocks = 1;
@@ -527,8 +527,8 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
527527
auto order = getOrder();
528528
unsigned stride = 1;
529529
auto contigPerThread = getEncoding().getContigPerThread();
530-
auto threadsPerWarp = getEncoding().getThreadsPerWarp();
531-
auto warpsPerCTA = getEncoding().getWarpsPerCTA();
530+
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
531+
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
532532
for (unsigned dim : order) {
533533
if (dim == getAxis())
534534
return stride;

lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ static SmallVector<Value> computeCrossWarpHistogram(
8888
Value threadId, int numWarps) {
8989
auto b = TritonLLVMOpBuilder(loc, rewriter);
9090
SmallVector<Value> histogramValues;
91-
unsigned numWarpsWithUniqueData = mlir::triton::gpu::getWarpsPerCTA(
92-
srcType.getEncoding(), srcType.getShape())[0];
91+
unsigned numWarpsWithUniqueData =
92+
mlir::triton::gpu::getWarpsPerCTAWithUniqueData(srcType.getEncoding(),
93+
srcType.getShape())[0];
9394
Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1));
9495
// Initialize the shared memory with zeros.
9596
int64_t numElementPerThread =

lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,8 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
392392
unsigned axis = helper.getAxis();
393393
auto srcEncoding = helper.getEncoding();
394394

395-
auto threadsPerWarp = srcEncoding.getThreadsPerWarp();
396-
auto warpsPerCTA = srcEncoding.getWarpsPerCTA();
395+
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
396+
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
397397
auto [multiDimLaneId, isRepresentativeLane] =
398398
getMultiDimLaneId(rewriter, helper, laneId);
399399
auto [multiDimWarpId, isRepresentativeWarp] =

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ struct TritonExpandDimsPattern
155155
retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1);
156156
auto retThreadsPerWarp = argEncoding.getThreadsPerWarp();
157157
retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.getAxis(), 1);
158-
auto retWarpsPerCTA = to_vector(argEncoding.getWarpsPerCTA());
158+
auto retWarpsPerCTA = argEncoding.getWarpsPerCTA();
159159
retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1);
160160
SmallVector<unsigned, 4> retOrder(retShape.size());
161161
std::iota(retOrder.begin(), retOrder.end(), 0);

0 commit comments

Comments (0)