Commit 2daef6c

lezcano authored and loislo committed
[LAYOUTS] [NFC] Just accept DistributedEncodings in SliceLayout (triton-lang#6004)
1 parent e5782c4 commit 2daef6c
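
In practice, this change means SliceEncodingAttr's parent is now typed as DistributedEncodingTrait rather than a bare Attribute, so call sites cast once up front. A minimal sketch of the new call pattern, assuming ctx, axis, and tensorTy are in scope (the variable names are illustrative, not from the diff):

    // Sketch: building a slice layout under the new signature.
    // cast<> asserts if the encoding is not a distributed encoding.
    mlir::Attribute enc = tensorTy.getEncoding();
    auto parent = mlir::cast<mlir::triton::gpu::DistributedEncodingTrait>(enc);
    auto slice = mlir::triton::gpu::SliceEncodingAttr::get(ctx, axis, parent);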

File tree

11 files changed: +64 -37 lines

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 2 deletions
@@ -1267,8 +1267,7 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
   let parameters = (
     ins
     "unsigned":$dim,
-    // TODO: constraint here to only take distributed encodings
-    "Attribute":$parent
+    "DistributedEncodingTrait":$parent
   );

   let extraClassDeclaration = extraDistributedDeclaration # [{
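
With the parameter declared as DistributedEncodingTrait, the accessor that MLIR's attribute codegen emits should return the interface directly, so users of the attribute no longer cast the parent themselves. A sketch of the expected generated signature (assumed, not shown in this commit):

    // Hypothetical generated accessor after this change:
    DistributedEncodingTrait SliceEncodingAttr::getParent() const;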

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 9 additions & 3 deletions
@@ -1592,7 +1592,12 @@ Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseGreater().failed())
     return {};
   unsigned dim = mlir::cast<IntegerAttr>(attrs.get("dim")).getInt();
-  Attribute parent = attrs.get("parent");
+  auto parent = mlir::dyn_cast<DistributedEncodingTrait>(attrs.get("parent"));
+  if (!parent) {
+    parser.emitError(parser.getNameLoc(),
+                     "expected a distributed encoding trait");
+    return {};
+  }
   return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
 }

@@ -2285,8 +2290,9 @@ struct TritonGPUInferLayoutInterface
   LogicalResult
   inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
                         Attribute &resultEncoding) const override {
-    resultEncoding = SliceEncodingAttr::get(getDialect()->getContext(), axis,
-                                            operandEncoding);
+    resultEncoding =
+        SliceEncodingAttr::get(getDialect()->getContext(), axis,
+                               cast<DistributedEncodingTrait>(operandEncoding));
     return success();
   }
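
Callers that cannot guarantee a distributed parent can mirror the parser's guarded pattern with getChecked; a sketch, assuming an emitError callback, a candidate parentAttr, and ctx/dim in scope:

    // Reject a non-distributed parent with a diagnostic instead of
    // asserting inside SliceEncodingAttr::get.
    auto parent = mlir::dyn_cast<DistributedEncodingTrait>(parentAttr);
    if (!parent) {
      emitError() << "slice parent must be a distributed encoding";
      return {};
    }
    auto slice = SliceEncodingAttr::getChecked(emitError, ctx, dim, parent);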

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 3 additions & 3 deletions
@@ -428,8 +428,8 @@ static bool canUseTwoCTAs(triton::DotOp dotOp) {
   return true;
 }

-static Attribute
-replaceCTALayout(Attribute layout,
+static DistributedEncodingTrait
+replaceCTALayout(DistributedEncodingTrait layout,
                  const triton::gpu::CTALayoutAttr &newCTALayout) {
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(layout)) {
     return BlockedEncodingAttr::get(

@@ -454,7 +454,7 @@ static Value splitBOperand(Value b, mlir::PatternRewriter &rewriter) {
   auto loadOp = b.getDefiningOp<triton::LoadOp>();
   assert(loadOp && "expected LoadOp");
   RankedTensorType bType = cast<RankedTensorType>(b.getType());
-  Attribute currentLayout = bType.getEncoding();
+  auto currentLayout = cast<DistributedEncodingTrait>(bType.getEncoding());
   auto newCTALayout =
       CTALayoutAttr::get(ctx, {1, 2}, {1, 2}, getCTAOrder(currentLayout));
   Attribute newLayout = replaceCTALayout(currentLayout, newCTALayout);
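
Note cast<> rather than dyn_cast<> here: the operand of a LoadOp feeding a dot presumably always carries a distributed encoding, so a mismatch would be a compiler bug and asserting is the right behavior. A defensive sketch for contexts where the encoding could legitimately be something else:

    // Skip rather than assert on non-distributed encodings (illustrative).
    if (auto dist = dyn_cast<DistributedEncodingTrait>(bType.getEncoding())) {
      Attribute newLayout = replaceCTALayout(dist, newCTALayout);
      // ... continue the rewrite with newLayout
    }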

lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp

Lines changed: 2 additions & 1 deletion
@@ -544,7 +544,8 @@ class TritonGPUOptimizeThreadLocalityPass
     return viewOpTensorShape;
   }

-  Attribute getThreadLocalityOptimizedEncoding(triton::ReduceOp reduce) const {
+  BlockedEncodingAttr
+  getThreadLocalityOptimizedEncoding(triton::ReduceOp reduce) const {
     auto srcType = cast<RankedTensorType>(reduce.getOperands()[0].getType());
     auto rank = srcType.getShape().size();
     auto srcEncoding = srcType.getEncoding();
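
Narrowing the return type from Attribute to BlockedEncodingAttr documents that this optimization always produces a blocked layout, and callers gain the concrete accessors without a cast; for example (reduce assumed to be the triton::ReduceOp in scope):

    // Blocked-specific fields are now directly queryable:
    BlockedEncodingAttr enc = getThreadLocalityOptimizedEncoding(reduce);
    auto sizePerThread = enc.getSizePerThread(); // unavailable on Attribute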

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 6 additions & 4 deletions
@@ -302,8 +302,9 @@ std::string GraphLayoutMarker::getColor(const Type &type) const {
 // -------------------------------------------------------------------------- //

 static Attribute inferDstEncoding(triton::ReduceOp op, Attribute encoding) {
-  return triton::gpu::SliceEncodingAttr::get(op->getContext(), op.getAxis(),
-                                             encoding);
+  return triton::gpu::SliceEncodingAttr::get(
+      op->getContext(), op.getAxis(),
+      cast<ttg::DistributedEncodingTrait>(encoding));
 }

 static Attribute inferDstEncoding(triton::ExpandDimsOp op, Attribute encoding) {

@@ -351,8 +352,9 @@ static Attribute inferSrcEncoding(triton::ReduceOp op, Attribute encoding) {
 }

 static Attribute inferSrcEncoding(triton::ExpandDimsOp op, Attribute encoding) {
-  return triton::gpu::SliceEncodingAttr::get(op->getContext(), op.getAxis(),
-                                             encoding);
+  return triton::gpu::SliceEncodingAttr::get(
+      op->getContext(), op.getAxis(),
+      cast<ttg::DistributedEncodingTrait>(encoding));
 }

 static Attribute inferSrcEncoding(JoinOp op, Attribute dstEnc) {
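
The ttg:: prefix in these hunks is presumably the namespace alias these files already declare (an assumption, but consistent with the usage above):

    namespace ttg = mlir::triton::gpu;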

lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp

Lines changed: 27 additions & 16 deletions
@@ -57,8 +57,10 @@ Type replaceLayout(const Type &type, const Attribute &newLayout) {
   return curType;
 }

-Attribute replaceCTALayout(Attribute layout, llvm::ArrayRef<int64_t> shape,
-                           const ttg::CTALayoutAttr &newCTALayout) {
+ttg::DistributedEncodingTrait
+replaceCTALayout(ttg::DistributedEncodingTrait layout,
+                 llvm::ArrayRef<int64_t> shape,
+                 const ttg::CTALayoutAttr &newCTALayout) {
   if (auto blockedLayout = mlir::dyn_cast<ttg::BlockedEncodingAttr>(layout)) {
     return ttg::BlockedEncodingAttr::get(
         layout.getContext(), shape, blockedLayout.getSizePerThread(),

@@ -120,9 +122,9 @@ class CTAPlanner {

   bool processBroadcast(triton::BroadcastOp broadcast, Attribute layout);
   bool processExpandDimsBackward(triton::ExpandDimsOp expandDims,
-                                 Attribute newResultLayout);
+                                 ttg::DistributedEncodingTrait newResultLayout);
   bool processExpandDimsForward(triton::ExpandDimsOp expandDims,
-                                Attribute newSrcLayout);
+                                ttg::DistributedEncodingTrait newSrcLayout);

   bool processConvertLayoutBackward(ttg::ConvertLayoutOp convertLayout,
                                     CastOp cast);

@@ -361,7 +363,8 @@ bool CTAPlanner::processReduce(triton::FuncOp &funcOp) {
       ttg::CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
   if (!tiled)
     setTiling(CTALayout.getCTAsPerCGA());
-  auto newSrcLayout = replaceCTALayout(srcLayout, srcShape, CTALayout);
+  auto newSrcLayout = replaceCTALayout(
+      cast<ttg::DistributedEncodingTrait>(srcLayout), srcShape, CTALayout);
   auto newResultLayout =
       ttg::SliceEncodingAttr::get(context, axis, newSrcLayout);
   unsigned numOperands = reduce.getNumOperands();

@@ -393,8 +396,9 @@ void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
         CTALayout = ttg::getCTALayout(tensorTy.getEncoding());
         setTiling(CTALayout.getCTAsPerCGA());
       }
-      auto newLayout = replaceCTALayout(tensorTy.getEncoding(),
-                                        tensorTy.getShape(), CTALayout);
+      auto newLayout = replaceCTALayout(
+          cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding()),
+          tensorTy.getShape(), CTALayout);
       processElementwise(store, newLayout);
     }
   }

@@ -421,7 +425,8 @@ bool CTAPlanner::propagateBackward(CastOp cast) {
   Type outTy = output.getType();
   if (auto ptrTy = dyn_cast<triton::PointerType>(outTy))
     outTy = ptrTy.getPointeeType();
-  Attribute layout = mlir::cast<RankedTensorType>(outTy).getEncoding();
+  auto layout = mlir::cast<ttg::DistributedEncodingTrait>(
+      mlir::cast<RankedTensorType>(outTy).getEncoding());
   Operation *op = input.getDefiningOp();
   if (op == nullptr) {
     assert(isa<BlockArgument>(input) &&

@@ -626,8 +631,10 @@ bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) {
     if (auto ptrTy = dyn_cast<triton::PointerType>(type))
       type = ptrTy.getPointeeType();
     auto tensorTy = cast<RankedTensorType>(type);
-    auto newLayout = replaceCTALayout(tensorTy.getEncoding(),
-                                      tensorTy.getShape(), CTALayout);
+    auto oldLayout =
+        cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding());
+    auto newLayout =
+        replaceCTALayout(oldLayout, tensorTy.getShape(), CTALayout);
     newOperandLayouts.push_back(newLayout);
   }

@@ -637,8 +644,10 @@ bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) {
     if (auto ptrTy = dyn_cast<triton::PointerType>(type))
      type = ptrTy.getPointeeType();
     auto tensorTy = cast<RankedTensorType>(type);
-    auto newLayout = replaceCTALayout(tensorTy.getEncoding(),
-                                      tensorTy.getShape(), CTALayout);
+    auto oldLayout =
+        cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding());
+    auto newLayout =
+        replaceCTALayout(oldLayout, tensorTy.getShape(), CTALayout);
     newResultLayouts.push_back(newLayout);
   }

@@ -725,16 +734,18 @@ bool CTAPlanner::processBroadcast(triton::BroadcastOp broadcast,
   return true;
 }

-bool CTAPlanner::processExpandDimsBackward(triton::ExpandDimsOp expandDims,
-                                           Attribute newResultLayout) {
+bool CTAPlanner::processExpandDimsBackward(
+    triton::ExpandDimsOp expandDims,
+    ttg::DistributedEncodingTrait newResultLayout) {
   auto newSrcLayout = ttg::SliceEncodingAttr::get(
       newResultLayout.getContext(), expandDims.getAxis(), newResultLayout);
   insertCasts(expandDims.getOperation(), {newSrcLayout}, {newResultLayout});
   return true;
 }

-bool CTAPlanner::processExpandDimsForward(triton::ExpandDimsOp expandDims,
-                                          Attribute newSrcLayout) {
+bool CTAPlanner::processExpandDimsForward(
+    triton::ExpandDimsOp expandDims,
+    ttg::DistributedEncodingTrait newSrcLayout) {
   llvm::report_fatal_error("processExpandDimsForward not implemented yet");
   return true;
 }
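
Returning ttg::DistributedEncodingTrait still lets each branch of replaceCTALayout return a concrete attribute such as ttg::BlockedEncodingAttr: an attribute that implements an attribute interface converts to the interface type implicitly. A minimal illustration:

    // Implicit conversion from a concrete encoding to the trait interface:
    ttg::DistributedEncodingTrait asTrait(ttg::BlockedEncodingAttr blocked) {
      return blocked; // wraps the attribute in the interface handle
    }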

third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 2 additions & 1 deletion
@@ -67,7 +67,8 @@ struct DecomposeUnsupportedAMDConversions
     auto srcType = cvtOp.getSrc().getType();
     auto dstType = cvtOp.getType();

-    auto srcEnc = srcType.getEncoding();
+    auto srcEnc =
+        cast<triton::gpu::DistributedEncodingTrait>(srcType.getEncoding());
     auto dstBlocked =
         dyn_cast<triton::gpu::BlockedEncodingAttr>(dstType.getEncoding());

third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp

Lines changed: 4 additions & 2 deletions
@@ -92,8 +92,10 @@ class OptimizeAMDLDSUsage
     auto srcType = cvtOp.getSrc().getType();
     auto dstType = cvtOp.getType();

-    auto srcEnc = srcType.getEncoding();
-    auto dstEnc = dstType.getEncoding();
+    auto srcEnc =
+        cast<triton::gpu::DistributedEncodingTrait>(srcType.getEncoding());
+    auto dstEnc =
+        cast<triton::gpu::DistributedEncodingTrait>(dstType.getEncoding());

     auto ctx = srcEnc.getContext();
     auto rank = srcType.getRank();

third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.cpp

Lines changed: 6 additions & 3 deletions
@@ -49,7 +49,9 @@ std::vector<SmallVector<unsigned>> factorizePowerOf2(int n, int rank) {
   return factors;
 }

-Attribute createTmpLayout(Attribute layout, ArrayRef<unsigned> warpsPerCTA) {
+triton::gpu::DistributedEncodingTrait
+createTmpLayout(triton::gpu::DistributedEncodingTrait layout,
+                ArrayRef<unsigned> warpsPerCTA) {
   auto ctx = layout.getContext();
   if (auto src = dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(layout))
     return triton::gpu::AMDMfmaEncodingAttr::get(

@@ -65,8 +67,9 @@ Attribute createTmpLayout(Attribute layout, ArrayRef<unsigned> warpsPerCTA) {
         ctx, src.getSizePerThread(), src.getThreadsPerWarp(), warpsPerCTA,
         src.getOrder(), src.getCTALayout());
   if (auto src = dyn_cast<triton::gpu::DotOperandEncodingAttr>(layout)) {
+    auto parent = cast<triton::gpu::DistributedEncodingTrait>(src.getParent());
     return triton::gpu::DotOperandEncodingAttr::get(
-        ctx, src.getOpIdx(), createTmpLayout(src.getParent(), warpsPerCTA),
+        ctx, src.getOpIdx(), createTmpLayout(parent, warpsPerCTA),
         src.getKWidth());
   }
   if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {

@@ -77,7 +80,7 @@ Attribute createTmpLayout(Attribute layout, ArrayRef<unsigned> warpsPerCTA) {
         ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA));
   }
   assert("Encountered unsupported layout");
-  return Attribute();
+  return {};
 }

 std::pair<triton::gpu::ConvertLayoutOp, triton::gpu::ConvertLayoutOp>
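
With the trait as the return type, the unsupported-layout path's return {}; yields a null DistributedEncodingTrait, which callers can test like any MLIR attribute handle; a sketch, assuming a LogicalResult-returning caller:

    // The null state is detectable at the call site (illustrative):
    auto tmp = createTmpLayout(layout, warpsPerCTA);
    if (!tmp)
      return failure(); // unsupported layout kind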

third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUtility.h

Lines changed: 3 additions & 1 deletion
@@ -16,7 +16,9 @@ std::vector<SmallVector<unsigned>> factorizePowerOf2(int n, int rank);
 /// \param layout original layout
 /// \param warpsPerCTA new warpsPerCTA
 /// \returns create layout
-Attribute createTmpLayout(Attribute layout, ArrayRef<unsigned> warpsPerCTA);
+triton::gpu::DistributedEncodingTrait
+createTmpLayout(triton::gpu::DistributedEncodingTrait layout,
+                ArrayRef<unsigned> warpsPerCTA);

 /// Creates two chained convert layout operations
 ///
