Skip to content

Commit 3563aec

Browse files
authored
[DIALECT][NFC] Move all GPU op related functions to ops.cpp (#5412)
Also does the following clean ups: 1. Remove unused functions and variables 2. Change `dyn_cast` to `cast` to get `DialectInferLayoutInterface` because `retEncoding` should always be available when `arg` has an encoding.
1 parent 82fec37 commit 3563aec

File tree

6 files changed

+411
-433
lines changed

6 files changed

+411
-433
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -762,8 +762,6 @@ for
762762
];
763763

764764
let extraClassDeclaration = extraDistributedDeclaration # [{
765-
SliceEncodingAttr squeeze(int axis);
766-
767765
SmallVector<unsigned> getContigPerThread() {
768766
// Block encoding is dense stride layout. The elements per thread are contiguous.
769767
return getSizePerThread();

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ LogicalResult TransOp::inferReturnTypes(
221221
Attribute retEncoding;
222222
if (argEncoding) {
223223
Dialect &dialect = argEncoding.getDialect();
224-
auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
224+
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
225225
if (inferLayoutInterface
226226
->inferTransOpEncoding(argEncoding, order, retEncoding)
227227
.failed()) {
@@ -250,7 +250,7 @@ DotOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
250250
if (aEnc) {
251251
assert(bEnc && retEnc);
252252
Dialect &dialect = retEnc.getDialect();
253-
auto interface = dyn_cast<DialectInferLayoutInterface>(&dialect);
253+
auto interface = cast<DialectInferLayoutInterface>(&dialect);
254254
if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed())
255255
return failure();
256256
if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed())
@@ -331,8 +331,7 @@ inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
331331
Attribute retEncoding;
332332
if (argEncoding) {
333333
Dialect &dialect = argEncoding.getDialect();
334-
auto inferLayoutInterface =
335-
dyn_cast<DialectInferLayoutInterface>(&dialect);
334+
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
336335
if (inferLayoutInterface
337336
->inferReduceOpEncoding(argEncoding, axis, retEncoding)
338337
.failed()) {
@@ -565,7 +564,7 @@ LogicalResult ExpandDimsOp::inferReturnTypes(
565564
Attribute retEncoding;
566565
if (argEncoding) {
567566
Dialect &dialect = argEncoding.getDialect();
568-
auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
567+
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
569568
if (inferLayoutInterface
570569
->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc)
571570
.failed())
@@ -604,7 +603,7 @@ LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op,
604603
// Infer the encoding of the new expand op, if encodings are present.
605604
Attribute newExpandEnc;
606605
if (auto srcEnc = srcTy.getEncoding()) {
607-
if (dyn_cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
606+
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
608607
->inferExpandDimsOpEncoding(srcEnc, op.getAxis(), newExpandEnc,
609608
op.getLoc())
610609
.failed()) {
@@ -975,7 +974,6 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
975974
assert(isa<RankedTensorType>(operands[1].getType()));
976975

977976
Value lhs = operands[0];
978-
Value rhs = operands[1];
979977
auto srcTy = cast<RankedTensorType>(lhs.getType());
980978

981979
SmallVector<int64_t> retShape(srcTy.getShape());
@@ -984,7 +982,7 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
984982
Attribute srcEnc = srcTy.getEncoding();
985983
Attribute retEnc;
986984
if (srcEnc) {
987-
if (dyn_cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
985+
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
988986
->inferJoinOpEncoding(srcEnc, retEnc, location)
989987
.failed()) {
990988
return failure();
@@ -1017,7 +1015,7 @@ LogicalResult SplitOp::inferReturnTypes(
10171015
Attribute srcEnc = srcTy.getEncoding();
10181016
Attribute retEnc;
10191017
if (srcEnc) {
1020-
if (dyn_cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
1018+
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
10211019
->inferSplitOpEncoding(srcEnc, retEnc, location)
10221020
.failed()) {
10231021
return failure();

0 commit comments

Comments
 (0)