@@ -39,7 +39,7 @@ namespace {
3939//
4040// Return success only for extensions from `i8` to `i32`.
4141template <typename Op>
42- std::optional<Value> getExtOperand (Value v, Type i8Ty, Type i32Ty ) {
42+ std::optional<Value> getExtOperand (Value v) {
4343
4444 static_assert (llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
4545 " Must be instantiated with either sign- or zero- extension op" );
@@ -50,7 +50,7 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
5050 if (!extOp) {
5151 if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
5252 auto vTy = cast<VectorType>(v.getType ());
53- if (vTy.getElementType () != i8Ty )
53+ if (! vTy.getElementType (). isSignlessInteger ( 8 ) )
5454 return {};
5555 return v;
5656 }
@@ -61,11 +61,11 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
6161 // operation type, check it's extended from `i8` to `i32`.
6262 auto inOp = extOp.getIn ();
6363 auto inTy = dyn_cast<VectorType>(inOp.getType ());
64- if (!inTy || inTy.getElementType () != i8Ty )
64+ if (!inTy || ! inTy.getElementType (). isSignlessInteger ( 8 ) )
6565 return {};
6666
6767 auto outTy = dyn_cast<VectorType>(extOp.getType ());
68- if (!outTy || outTy.getElementType () != i32Ty )
68+ if (!outTy || ! outTy.getElementType (). isSignlessInteger ( 32 ) )
6969 return {};
7070
7171 return inOp;
@@ -199,27 +199,23 @@ class LowerContractionToSVEI8MMPattern
199199 // operands are supported, but they are lowered to different operations.
200200 // Determine which is the appropriate operation to lower to.
201201 MMLA mmlaOp = MMLA::Signed;
202- auto maybeLhs = getExtOperand<arith::ExtSIOp>(
203- op.getLhs (), rewriter.getI8Type (), rewriter.getI32Type ());
202+ auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs ());
204203 if (!maybeLhs) {
205204 mmlaOp = MMLA::Unsigned;
206- maybeLhs = getExtOperand<arith::ExtUIOp>(
207- op.getLhs (), rewriter.getI8Type (), rewriter.getI32Type ());
205+ maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs ());
208206 }
209207 if (!maybeLhs)
210208 return rewriter.notifyMatchFailure (
211209 op, " LHS is not a sign- or zero- extended i8" );
212210
213- auto maybeRhs = getExtOperand<arith::ExtSIOp>(
214- op.getRhs (), rewriter.getI8Type (), rewriter.getI32Type ());
211+ auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs ());
215212 if (maybeRhs) {
216213 if (mmlaOp == MMLA::Unsigned)
217214 mmlaOp = MMLA::Mixed;
218215 } else {
219216 if (mmlaOp == MMLA::Signed)
220217 mmlaOp = MMLA::MixedSwapped;
221- maybeRhs = getExtOperand<arith::ExtUIOp>(
222- op.getRhs (), rewriter.getI8Type (), rewriter.getI32Type ());
218+ maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs ());
223219 }
224220 if (!maybeRhs)
225221 return rewriter.notifyMatchFailure (
0 commit comments