diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp index b1233c5c06eb4..a1209fe8230e2 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp @@ -39,7 +39,7 @@ namespace { // // Return success only for extensions from `i8` to `i32`. template -std::optional getExtOperand(Value v, Type i8Ty, Type i32Ty) { +std::optional getExtOperand(Value v) { static_assert(llvm::is_one_of::value, "Must be instantiated with either sign- or zero- extension op"); @@ -50,7 +50,7 @@ std::optional getExtOperand(Value v, Type i8Ty, Type i32Ty) { if (!extOp) { if constexpr (std::is_same::value) { auto vTy = cast(v.getType()); - if (vTy.getElementType() != i8Ty) + if (!vTy.getElementType().isSignlessInteger(8)) return {}; return v; } @@ -61,11 +61,11 @@ std::optional getExtOperand(Value v, Type i8Ty, Type i32Ty) { // operation type, check it's extended from `i8` to `i32`. auto inOp = extOp.getIn(); auto inTy = dyn_cast(inOp.getType()); - if (!inTy || inTy.getElementType() != i8Ty) + if (!inTy || !inTy.getElementType().isSignlessInteger(8)) return {}; auto outTy = dyn_cast(extOp.getType()); - if (!outTy || outTy.getElementType() != i32Ty) + if (!outTy || !outTy.getElementType().isSignlessInteger(32)) return {}; return inOp; @@ -199,27 +199,23 @@ class LowerContractionToSVEI8MMPattern // operands are supported, but they are lowered to different operations. // Determine which is the appropriate operation to lower to. MMLA mmlaOp = MMLA::Signed; - auto maybeLhs = getExtOperand( - op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type()); + auto maybeLhs = getExtOperand(op.getLhs()); if (!maybeLhs) { mmlaOp = MMLA::Unsigned; - maybeLhs = getExtOperand( - op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type()); + maybeLhs = getExtOperand(op.getLhs()); } if (!maybeLhs) return rewriter.notifyMatchFailure( op, "LHS is not a sign- or zero- extended i8"); - auto maybeRhs = getExtOperand( - op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type()); + auto maybeRhs = getExtOperand(op.getRhs()); if (maybeRhs) { if (mmlaOp == MMLA::Unsigned) mmlaOp = MMLA::Mixed; } else { if (mmlaOp == MMLA::Signed) mmlaOp = MMLA::MixedSwapped; - maybeRhs = getExtOperand( - op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type()); + maybeRhs = getExtOperand(op.getRhs()); } if (!maybeRhs) return rewriter.notifyMatchFailure(