@@ -221,7 +221,7 @@ LogicalResult TransOp::inferReturnTypes(
221221 Attribute retEncoding;
222222 if (argEncoding) {
223223 Dialect &dialect = argEncoding.getDialect ();
224- auto inferLayoutInterface = dyn_cast <DialectInferLayoutInterface>(&dialect);
224+ auto inferLayoutInterface = cast <DialectInferLayoutInterface>(&dialect);
225225 if (inferLayoutInterface
226226 ->inferTransOpEncoding (argEncoding, order, retEncoding)
227227 .failed ()) {
@@ -250,7 +250,7 @@ DotOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
250250 if (aEnc) {
251251 assert (bEnc && retEnc);
252252 Dialect &dialect = retEnc.getDialect ();
253- auto interface = dyn_cast <DialectInferLayoutInterface>(&dialect);
253+ auto interface = cast <DialectInferLayoutInterface>(&dialect);
254254 if (interface->inferDotOpEncoding (aEnc, 0 , retEnc, location).failed ())
255255 return failure ();
256256 if (interface->inferDotOpEncoding (bEnc, 1 , retEnc, location).failed ())
@@ -331,8 +331,7 @@ inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
331331 Attribute retEncoding;
332332 if (argEncoding) {
333333 Dialect &dialect = argEncoding.getDialect ();
334- auto inferLayoutInterface =
335- dyn_cast<DialectInferLayoutInterface>(&dialect);
334+ auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
336335 if (inferLayoutInterface
337336 ->inferReduceOpEncoding (argEncoding, axis, retEncoding)
338337 .failed ()) {
@@ -565,7 +564,7 @@ LogicalResult ExpandDimsOp::inferReturnTypes(
565564 Attribute retEncoding;
566565 if (argEncoding) {
567566 Dialect &dialect = argEncoding.getDialect ();
568- auto inferLayoutInterface = dyn_cast <DialectInferLayoutInterface>(&dialect);
567+ auto inferLayoutInterface = cast <DialectInferLayoutInterface>(&dialect);
569568 if (inferLayoutInterface
570569 ->inferExpandDimsOpEncoding (argEncoding, axis, retEncoding, loc)
571570 .failed ())
@@ -604,7 +603,7 @@ LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op,
604603 // Infer the encoding of the new expand op, if encodings are present.
605604 Attribute newExpandEnc;
606605 if (auto srcEnc = srcTy.getEncoding ()) {
607- if (dyn_cast <DialectInferLayoutInterface>(&srcEnc.getDialect ())
606+ if (cast <DialectInferLayoutInterface>(&srcEnc.getDialect ())
608607 ->inferExpandDimsOpEncoding (srcEnc, op.getAxis (), newExpandEnc,
609608 op.getLoc ())
610609 .failed ()) {
@@ -975,7 +974,6 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
975974 assert (isa<RankedTensorType>(operands[1 ].getType ()));
976975
977976 Value lhs = operands[0 ];
978- Value rhs = operands[1 ];
979977 auto srcTy = cast<RankedTensorType>(lhs.getType ());
980978
981979 SmallVector<int64_t > retShape (srcTy.getShape ());
@@ -984,7 +982,7 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
984982 Attribute srcEnc = srcTy.getEncoding ();
985983 Attribute retEnc;
986984 if (srcEnc) {
987- if (dyn_cast <DialectInferLayoutInterface>(&srcEnc.getDialect ())
985+ if (cast <DialectInferLayoutInterface>(&srcEnc.getDialect ())
988986 ->inferJoinOpEncoding (srcEnc, retEnc, location)
989987 .failed ()) {
990988 return failure ();
@@ -1017,7 +1015,7 @@ LogicalResult SplitOp::inferReturnTypes(
10171015 Attribute srcEnc = srcTy.getEncoding ();
10181016 Attribute retEnc;
10191017 if (srcEnc) {
1020- if (dyn_cast <DialectInferLayoutInterface>(&srcEnc.getDialect ())
1018+ if (cast <DialectInferLayoutInterface>(&srcEnc.getDialect ())
10211019 ->inferSplitOpEncoding (srcEnc, retEnc, location)
10221020 .failed ()) {
10231021 return failure ();
0 commit comments