From 7c5c8eecfcce07d5552efc2d5b00f0e468ff7e8b Mon Sep 17 00:00:00 2001 From: Momchil Velikov Date: Thu, 19 Jun 2025 15:03:30 +0000 Subject: [PATCH] [MLIR][AArch64] Simplify LowerContractionToSVEI8MMPattern.cpp:getExtOperand (NFC) Just recently learned about `isSignlessInteger`, use that instead of comparing to types obtained via `rewriter.getIType()`. --- .../LowerContractionToSVEI8MMPattern.cpp | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp index b1233c5c06eb4..a1209fe8230e2 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp @@ -39,7 +39,7 @@ namespace { // // Return success only for extensions from `i8` to `i32`. template -std::optional getExtOperand(Value v, Type i8Ty, Type i32Ty) { +std::optional getExtOperand(Value v) { static_assert(llvm::is_one_of::value, "Must be instantiated with either sign- or zero- extension op"); @@ -50,7 +50,7 @@ std::optional getExtOperand(Value v, Type i8Ty, Type i32Ty) { if (!extOp) { if constexpr (std::is_same::value) { auto vTy = cast(v.getType()); - if (vTy.getElementType() != i8Ty) + if (!vTy.getElementType().isSignlessInteger(8)) return {}; return v; } @@ -61,11 +61,11 @@ std::optional getExtOperand(Value v, Type i8Ty, Type i32Ty) { // operation type, check it's extended from `i8` to `i32`. auto inOp = extOp.getIn(); auto inTy = dyn_cast(inOp.getType()); - if (!inTy || inTy.getElementType() != i8Ty) + if (!inTy || !inTy.getElementType().isSignlessInteger(8)) return {}; auto outTy = dyn_cast(extOp.getType()); - if (!outTy || outTy.getElementType() != i32Ty) + if (!outTy || !outTy.getElementType().isSignlessInteger(32)) return {}; return inOp; @@ -199,27 +199,23 @@ class LowerContractionToSVEI8MMPattern // operands are supported, but they are lowered to different operations. // Determine which is the appropriate operation to lower to. MMLA mmlaOp = MMLA::Signed; - auto maybeLhs = getExtOperand( - op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type()); + auto maybeLhs = getExtOperand(op.getLhs()); if (!maybeLhs) { mmlaOp = MMLA::Unsigned; - maybeLhs = getExtOperand( - op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type()); + maybeLhs = getExtOperand(op.getLhs()); } if (!maybeLhs) return rewriter.notifyMatchFailure( op, "LHS is not a sign- or zero- extended i8"); - auto maybeRhs = getExtOperand( - op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type()); + auto maybeRhs = getExtOperand(op.getRhs()); if (maybeRhs) { if (mmlaOp == MMLA::Unsigned) mmlaOp = MMLA::Mixed; } else { if (mmlaOp == MMLA::Signed) mmlaOp = MMLA::MixedSwapped; - maybeRhs = getExtOperand( - op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type()); + maybeRhs = getExtOperand(op.getRhs()); } if (!maybeRhs) return rewriter.notifyMatchFailure(