Commit 6a6ed52

[BACKEND] Generalize getShapePerCTA (#7580)

1 parent 167ed28

File tree: 3 files changed, +49 −23 lines

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 8 additions & 2 deletions
@@ -221,13 +221,19 @@ SmallVector<unsigned> getCTAOrder(Attribute layout);
 // [FIXME LL] Kill this function
 SmallVector<unsigned> getShapePerCTATile(RankedTensorType layout);

-// Returns the "logical" shape per CTA
+// Returns the "logical" shape per CTA.
+// When shape and CTASplitNum have a different number of dimensions, we assume
+// that only the common trailing dimensions are split.
+// Example1: shape = [2, 4, 8], CTASplitNum = [2, 2], ret = [2, 2, 4].
+// This can arise from pipelining.
+// Example2: shape = [2, 4], CTASplitNum = [2, 2, 2], ret = [1, 2].
+// This can arise from memory slicing.
 SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                     ArrayRef<int64_t> shape);
 SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape);
 SmallVector<int64_t> getShapePerCTA(Type type);

-// Returns the shape per CTA, which is "physically" allocated
+// Returns the shape per CTA, which is "physically" allocated.
 // Such shapes may be bigger than the logical one due to, for example, padding
 // in shared memory.
 SmallVector<int64_t> getAllocationShapePerCTA(Attribute layout,
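The two examples in the new comment can be traced through a minimal standalone sketch of the alignment rule. This is a reconstruction for illustration, not Triton's code: it uses std::vector in place of llvm::SmallVector, and the function name shapePerCTASketch is hypothetical.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Align splitNum to the rank of shape (pad leading 1s, or drop leading
// entries), then divide elementwise, clamping each factor to the dim size.
std::vector<int64_t> shapePerCTASketch(std::vector<unsigned> splitNum,
                                       const std::vector<int64_t> &shape) {
  size_t rank = shape.size();
  if (splitNum.size() <= rank) // pipelining: leading dims stay unsplit
    splitNum.insert(splitNum.begin(), rank - splitNum.size(), 1u);
  else // memory slicing: only the trailing `rank` factors apply
    splitNum.erase(splitNum.begin(),
                   splitNum.begin() + (splitNum.size() - rank));
  std::vector<int64_t> result(rank);
  for (size_t i = 0; i < rank; ++i)
    result[i] = shape[i] / std::min<int64_t>(shape[i], splitNum[i]);
  return result;
}

int main() {
  // Example1 from the comment: pipelining pads splitNum to [1, 2, 2].
  assert((shapePerCTASketch({2, 2}, {2, 4, 8}) ==
          std::vector<int64_t>{2, 2, 4}));
  // Example2 from the comment: slicing drops the leading factor -> [2, 2].
  assert((shapePerCTASketch({2, 2, 2}, {2, 4}) ==
          std::vector<int64_t>{1, 2}));
}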

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 9 additions & 21 deletions
@@ -292,34 +292,22 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
 SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                     ArrayRef<int64_t> shape) {
   unsigned rank = shape.size();
+  auto splitNum = llvm::to_vector(CTASplitNum);
+  if (splitNum.size() <= rank) { // pipelining
+    splitNum.insert(splitNum.begin(), rank - splitNum.size(), 1);
+  } else { // memory slicing
+    splitNum =
+        llvm::to_vector(llvm::drop_begin(splitNum, splitNum.size() - rank));
+  }
   SmallVector<int64_t> shapePerCTA(rank);
   for (unsigned i = 0; i < rank; ++i) {
-    unsigned splitNum = std::min<unsigned>(shape[i], CTASplitNum[i]);
-    shapePerCTA[i] = shape[i] / splitNum;
+    shapePerCTA[i] = shape[i] / std::min<unsigned>(shape[i], splitNum[i]);
   }
   return shapePerCTA;
 }

 SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape) {
-  if (mlir::isa<SharedEncodingTrait>(layout)) {
-    // Special logic for pipeline pass, where shape is 3D and CTALayout is 2D.
-    // The first dim of shape is numStages. This is a work around, otherwise
-    // too many places would have to be modified in pipeline pass. Maybe we
-    // need to refactor this logic in the future.
-    auto CTASplitNum = cast<LayoutEncodingTrait>(layout).getCTASplitNum();
-    if (shape.size() == CTASplitNum.size() + 1) {
-      auto res = getShapePerCTA(CTASplitNum, shape.drop_front());
-      res.insert(res.begin(), shape.front());
-      return res;
-    }
-  }
-  SmallVector<unsigned> splitNum = getCTASplitNum(layout);
-  if (auto tmem = dyn_cast<nvidia_gpu::TensorMemoryEncodingAttr>(layout)) {
-    if (shape.size() > splitNum.size()) {
-      splitNum.insert(splitNum.begin(), shape.size() - splitNum.size(), 1);
-    }
-  }
-  return getShapePerCTA(splitNum, shape);
+  return getShapePerCTA(getCTASplitNum(layout), shape);
 }

 SmallVector<int64_t> getAllocationShapePerCTA(Attribute layout,
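The simplified layout-based overload now relies entirely on the generic alignment rule, which subsumes both deleted branches: the shared-memory pipelining case, where a 3D shape carries numStages in front of a 2D CTASplitNum, and the TensorMemoryEncodingAttr padding. Extending main() in the sketch above with a hypothetical pipelining-shaped input shows the equivalence:

  // numStages = 3: the leading dim is left unsplit, matching the removed
  // SharedEncodingTrait branch that re-inserted shape.front() at the front.
  assert((shapePerCTASketch({2, 4}, {3, 64, 128}) ==
          std::vector<int64_t>{3, 32, 32}));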

unittest/Dialect/TritonGPU/DialectTest.cpp

Lines changed: 32 additions & 0 deletions
@@ -367,6 +367,38 @@ TEST_F(Fp4ToFpOpTest, Fp4ToFpOpLayoutPropagation) {
   }
 }

+class ShapePerCTATest : public ::testing::Test {
+public:
+  ShapePerCTATest() { ctx.getOrLoadDialect<TritonGPUDialect>(); }
+
+protected:
+  MLIRContext ctx;
+};
+
+TEST_F(ShapePerCTATest, ShapePerCTA) {
+  // Equal length
+  SmallVector<unsigned> CTASplitNum = {2, 4};
+  SmallVector<int64_t> shape = {64, 128};
+  auto shapePerCTA = getShapePerCTA(CTASplitNum, shape);
+  auto expectedShapePerCTA = SmallVector<int64_t>{32, 32};
+  EXPECT_EQ(shapePerCTA.size(), shape.size());
+  EXPECT_EQ(shapePerCTA, expectedShapePerCTA);
+
+  // rank(shape) < rank(CTASplitNum)
+  CTASplitNum = {2, 4, 8};
+  shapePerCTA = getShapePerCTA(CTASplitNum, shape);
+  expectedShapePerCTA = SmallVector<int64_t>{16, 16};
+  EXPECT_EQ(shapePerCTA.size(), shape.size());
+  EXPECT_EQ(shapePerCTA, expectedShapePerCTA);
+
+  // rank(shape) > rank(CTASplitNum)
+  CTASplitNum = {2};
+  shapePerCTA = getShapePerCTA(CTASplitNum, shape);
+  expectedShapePerCTA = SmallVector<int64_t>{64, 64};
+  EXPECT_EQ(shapePerCTA.size(), shape.size());
+  EXPECT_EQ(shapePerCTA, expectedShapePerCTA);
+}
+
 class JoinOpTest : public ::testing::Test {
 public:
   JoinOpTest() { ctx.getOrLoadDialect<TritonGPUDialect>(); }
