Skip to content

Commit 1e98c47

Browse files
Merge commit '593a1b5b7b55a4b679d0955d6a2b5441e930595b'
2 parents 77b8626 + 593a1b5 commit 1e98c47

File tree

11 files changed

+43
-264
lines changed

11 files changed

+43
-264
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,16 @@ unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
9999

100100
SmallVector<unsigned> getElemsPerThread(Type type);
101101

102-
// Returns the number of threads per warp that may have access to replicated
103-
// elements. If you want non-replicated threads, use
104-
// getThreadsPerWarpWithUniqueData.
105-
SmallVector<unsigned> getThreadsPerWarp(Attribute layout);
106-
107-
// Returns the number of warps per CTA that may have access to replicated
108-
// elements. If you want non-replicated warps, use getWarpsPerCTAWithUniqueData.
109-
SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
102+
// Returns the number of warps per CTA that have access to non-replicated
103+
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
104+
// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2],
105+
// returns [1, 1], since the first warp has access to the full tensor, whereas
106+
// the other warps have access to replicated elements.
107+
SmallVector<unsigned> getWarpsPerCTA(Attribute layout,
108+
ArrayRef<int64_t> tensorShape);
109+
inline SmallVector<unsigned> getWarpsPerCTA(RankedTensorType type) {
110+
return getWarpsPerCTA(type.getEncoding(), type.getShape());
111+
}
110112

111113
// Returns the number of contiguous elements of the logical tensor that each
112114
// thread has access to, on each dimension of the tensor. For a blocked layout
@@ -125,14 +127,6 @@ SmallVector<unsigned>
125127
getThreadsPerWarpWithUniqueData(Attribute layout,
126128
ArrayRef<int64_t> tensorShape);
127129

128-
// Returns the number of warps per CTA that have access to non-replicated
129-
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
130-
// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2],
131-
// returns [1, 1], since the first warp has access to the full tensor, whereas
132-
// the other warps have access to replicated elements.
133-
SmallVector<unsigned>
134-
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
135-
136130
// Returns the dimensions of the tensor from minor (fast-varying) to
137131
// major (slow-varying). For distributed layouts, this represents
138132
// the order of the elements within a thread.

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -597,15 +597,6 @@ We call each individual tile "rep".
597597
/*defaultImplementation=*/[{
598598
return toLinearEncoding($_self, shape).getElemsPerThread(shape);
599599
}]>,
600-
// Interface for the meta information about the multiple thread hierarchy.
601-
InterfaceMethod<"Get the shape of the warps per CTA.",
602-
"SmallVector<unsigned>",
603-
"getWarpsPerCTA">,
604-
605-
606-
InterfaceMethod<"Get the shape of the threads per warp",
607-
"SmallVector<unsigned>",
608-
"getThreadsPerWarp">,
609600
InterfaceMethod<"Convert to LinearLayout.",
610601
"LinearLayout",
611602
"toLinearLayout",
@@ -671,8 +662,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
671662
SmallVector<unsigned> getCTAsPerCGA() const;
672663
SmallVector<unsigned> getCTAOrder() const;
673664
SmallVector<unsigned> getCTASplitNum() const;
674-
SmallVector<unsigned> getWarpsPerCTA() const;
675-
SmallVector<unsigned> getThreadsPerWarp() const;
676665

677666
LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
678667

@@ -725,6 +714,8 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
725714
// If skipBroadcast is false, we count a base zero
726715
SmallVector<unsigned> basesPerDim(StringAttr dimName,
727716
bool skipBroadcast = true) const;
717+
SmallVector<unsigned> getThreadsPerWarp() const;
718+
SmallVector<unsigned> getWarpsPerCTA() const;
728719

729720
// [FIXME LL] Supports legacy behaviour. We should remove these functions
730721
SmallVector<unsigned> getShapePerCTATile() const;
@@ -834,8 +825,8 @@ for
834825
let parameters = (
835826
ins
836827
ArrayRefParameter<"unsigned">:$sizePerThread,
837-
ArrayRefParameter<"unsigned">:$threadsPerWarp__,
838-
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
828+
ArrayRefParameter<"unsigned">:$threadsPerWarp,
829+
ArrayRefParameter<"unsigned">:$warpsPerCTA,
839830
ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first
840831

841832
// CTALayout is optional in the textual IR. If omitted, we infer it to be a
@@ -1039,7 +1030,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
10391030
ins
10401031
"unsigned": $versionMajor,
10411032
"unsigned": $versionMinor,
1042-
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
1033+
ArrayRefParameter<"unsigned">:$warpsPerCTA,
10431034
"unsigned":$MDim,
10441035
"unsigned":$NDim,
10451036
"bool":$isTransposed,
@@ -1160,7 +1151,7 @@ Row |
11601151
ins
11611152
"unsigned": $version,
11621153
"bool":$isTransposed,
1163-
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
1154+
ArrayRefParameter<"unsigned">:$warpsPerCTA,
11641155
"CTALayoutAttr":$CTALayout
11651156
);
11661157

@@ -1266,7 +1257,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12661257
ins
12671258
"unsigned":$versionMajor,
12681259
"unsigned":$versionMinor,
1269-
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
1260+
ArrayRefParameter<"unsigned">:$warpsPerCTA,
12701261
"CTALayoutAttr":$CTALayout,
12711262
ArrayRefParameter<"unsigned">:$instrShape
12721263
);

lib/Analysis/Utility.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) {
100100
}
101101

102102
unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
103-
return getWarpsPerCTAWithUniqueData(srcEncoding, srcShape)[axis];
103+
return getWarpsPerCTA(srcEncoding, srcShape)[axis];
104104
}
105105

106106
unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
@@ -113,7 +113,7 @@ bool ReduceOpHelper::isWarpSynchronous() {
113113
// in order to remove this change.
114114
if (!srcEncoding)
115115
return true;
116-
return getWarpsPerCTAWithUniqueData(srcEncoding, srcShape)[axis] == 1;
116+
return getWarpsPerCTA(srcEncoding, srcShape)[axis] == 1;
117117
}
118118

119119
SmallVector<unsigned> ReduceOpHelper::getScratchRepShape() {
@@ -180,8 +180,8 @@ unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() {
180180

181181
unsigned ScanLoweringHelper::getAxisNumBlocks() {
182182
auto contigPerThread = getEncoding().getContigPerThread();
183-
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
184-
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
183+
auto threadsPerWarp = getEncoding().getThreadsPerWarp();
184+
auto warpsPerCTA = getEncoding().getWarpsPerCTA();
185185
unsigned axis = getAxis();
186186
return ceil<unsigned>(
187187
getShape()[axis],
@@ -190,8 +190,8 @@ unsigned ScanLoweringHelper::getAxisNumBlocks() {
190190

191191
unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
192192
auto contigPerThread = getEncoding().getContigPerThread();
193-
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
194-
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
193+
auto threadsPerWarp = getEncoding().getThreadsPerWarp();
194+
auto warpsPerCTA = getEncoding().getWarpsPerCTA();
195195
auto rank = contigPerThread.size();
196196
unsigned axis = getAxis();
197197
unsigned numBlocks = 1;
@@ -527,8 +527,8 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
527527
auto order = getOrder();
528528
unsigned stride = 1;
529529
auto contigPerThread = getEncoding().getContigPerThread();
530-
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
531-
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
530+
auto threadsPerWarp = getEncoding().getThreadsPerWarp();
531+
auto warpsPerCTA = getEncoding().getWarpsPerCTA();
532532
for (unsigned dim : order) {
533533
if (dim == getAxis())
534534
return stride;

lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,8 @@ static SmallVector<Value> computeCrossWarpHistogram(
8888
Value threadId, int numWarps) {
8989
auto b = TritonLLVMOpBuilder(loc, rewriter);
9090
SmallVector<Value> histogramValues;
91-
unsigned numWarpsWithUniqueData =
92-
mlir::triton::gpu::getWarpsPerCTAWithUniqueData(srcType.getEncoding(),
93-
srcType.getShape())[0];
91+
unsigned numWarpsWithUniqueData = mlir::triton::gpu::getWarpsPerCTA(
92+
srcType.getEncoding(), srcType.getShape())[0];
9493
Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1));
9594
// Initialize the shared memory with zeros.
9695
int64_t numElementPerThread =

lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,8 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter,
392392
unsigned axis = helper.getAxis();
393393
auto srcEncoding = helper.getEncoding();
394394

395-
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
396-
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
395+
auto threadsPerWarp = srcEncoding.getThreadsPerWarp();
396+
auto warpsPerCTA = srcEncoding.getWarpsPerCTA();
397397
auto [multiDimLaneId, isRepresentativeLane] =
398398
getMultiDimLaneId(rewriter, helper, laneId);
399399
auto [multiDimWarpId, isRepresentativeWarp] =

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ struct TritonExpandDimsPattern
155155
retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1);
156156
auto retThreadsPerWarp = argEncoding.getThreadsPerWarp();
157157
retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.getAxis(), 1);
158-
auto retWarpsPerCTA = argEncoding.getWarpsPerCTA();
158+
auto retWarpsPerCTA = to_vector(argEncoding.getWarpsPerCTA());
159159
retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1);
160160
SmallVector<unsigned, 4> retOrder(retShape.size());
161161
std::iota(retOrder.begin(), retOrder.end(), 0);

0 commit comments

Comments (0)