diff --git a/externals/llvm-project b/externals/llvm-project
index d21b5c874658..a0fc10d350b9 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit d21b5c87465837690b6bd5f2eafac20808d3da39
+Subproject commit a0fc10d350b9b1b29767796f3faeb07132cd08fd
diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index 8beabb969ed8..fbef9085470d 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -49,6 +49,11 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
 Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                   float val);
 
+// Create an int8_t const tosa.mul shift tensor from an int when required for
+// the given result type. Returns a null Value when no shift operand is needed.
+Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
+                                 Type resultType, int32_t shift);
+
 // Create a zero constant tensor of the desired type and shape.
 std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
                                         Operation *op, Type type);
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 5bf8a3387d88..219775b92438 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -672,14 +672,15 @@ std::optional<Value> floorIntDiv(PatternRewriter &rewriter, Operation *op,
   auto boolType =
       RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1));
 
-  auto lhsMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type, lhs, rhs,
-                                                /*shift=*/0);
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, i32Type, 0);
+  auto lhsMulRhs =
+      rewriter.create<tosa::MulOp>(op->getLoc(), i32Type, lhs, rhs, mulShift);
 
   auto lhsRhsDifferentSign =
       rewriter.create<tosa::GreaterOp>(op->getLoc(), boolType, zero, lhsMulRhs);
 
   auto truncMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type,
-                                                  intDivOp, rhs, /*shift=*/0);
+                                                  intDivOp, rhs, mulShift);
 
   auto truncMulRhsEqualLhs =
       rewriter.create<tosa::EqualOp>(op->getLoc(), boolType, truncMulRhs, lhs);
@@ -918,12 +919,13 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
       op->getLoc(),
       RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)),
       self, zero);
 
-  auto mulTensor = rewriter.create<tosa::MulOp>(
-      op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
-      alphaTensor, /*shift=*/0);
+  auto resultType = getTypeConverter()->convertType(op.getType());
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, resultType, 0);
+  auto mulTensor = rewriter.create<tosa::MulOp>(op->getLoc(), resultType, self,
+                                                alphaTensor, mulShift);
 
-  rewriter.replaceOpWithNewOp<tosa::SelectOp>(
-      op, getTypeConverter()->convertType(op.getType()), cond, self, mulTensor);
+  rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, resultType, cond, self,
+                                              mulTensor);
 
   return success();
 }
@@ -2348,8 +2350,9 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Failed to equalize ranks among operands and result");
 
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, resultTy, 0);
   auto multTensor = rewriter.create<tosa::MulOp>(op->getLoc(), resultTy, self,
-                                                 alphaTensor, /*shift=*/0);
+                                                 alphaTensor, mulShift);
 
   rewriter.replaceOpWithNewOp<tosa::SubOp>(op, resultTy, otherTensor,
                                            multTensor);
@@ -2761,12 +2764,13 @@ std::optional<Value> computeBatchNorm(Operation *op,
   auto op3RsqrtOp2 = rewriter.create<tosa::RsqrtOp>(
       op->getLoc(), variance.getType(), op2AddVarEpsilon.getResult());
 
-  auto op4MulOp1Op3 = rewriter.create<tosa::MulOp>(op->getLoc(), outType,
-                                                   op1SubInputMean.getResult(),
-                                                   op3RsqrtOp2.getResult(), 0);
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, outType, 0);
+  auto op4MulOp1Op3 = rewriter.create<tosa::MulOp>(
+      op->getLoc(), outType, op1SubInputMean.getResult(),
+      op3RsqrtOp2.getResult(), mulShift);
 
   auto op5MulOp4Scale = rewriter.create<tosa::MulOp>(
-      op->getLoc(), outType, op4MulOp1Op3.getResult(), weight, 0);
+      op->getLoc(), outType, op4MulOp1Op3.getResult(), weight, mulShift);
 
   return rewriter
       .create<tosa::AddOp>(op->getLoc(), outType, op5MulOp4Scale.getResult(),
@@ -2989,19 +2993,25 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
   // Compute mean.
   Value sum = computeSumAndReshape(input, inputType, bcastOutType,
                                    bcastOutShape);
-  Value meanVal = rewriter.create<tosa::MulOp>(op.getLoc(), bcastOutType, sum,
-                                               elemCntRcp, /*shift=*/0);
+  Value meanVal = rewriter.create<tosa::MulOp>(
+      op.getLoc(), bcastOutType, sum, elemCntRcp,
+      tosa::getTosaMulShiftConstTensor(rewriter, op.getOperation(),
+                                       bcastOutType, 0));
 
   // Compute variance.
   Value squareSumSub =
       rewriter.create<tosa::SubOp>(op.getLoc(), inputType, input, meanVal);
-  Value squareSum = rewriter.create<tosa::MulOp>(op.getLoc(), inputType,
-                                                 squareSumSub, squareSumSub, 0);
+  Value squareSum = rewriter.create<tosa::MulOp>(
+      op.getLoc(), inputType, squareSumSub, squareSumSub,
+      tosa::getTosaMulShiftConstTensor(rewriter, op.getOperation(), inputType,
+                                       0));
 
   Value squareSumReduced =
       computeSumAndReshape(squareSum, inputType, bcastOutType, bcastOutShape);
   Value varianceVal = rewriter.create<tosa::MulOp>(
-      op.getLoc(), bcastOutType, squareSumReduced, elemCntRcp, /*shift=*/0);
+      op.getLoc(), bcastOutType, squareSumReduced, elemCntRcp,
+      tosa::getTosaMulShiftConstTensor(rewriter, op.getOperation(),
+                                       bcastOutType, 0));
 
   // Reshape weight and bias.
   SmallVector<int64_t> weightAndBiasBcastShape;
@@ -3259,8 +3269,8 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
       rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);
 
   auto logOp = rewriter.create<tosa::LogOp>(op.getLoc(), outType, self);
-  rewriter.replaceOpWithNewOp<tosa::MulOp>(op, outType, logOp, rcpOp,
-                                           /*shift=*/0);
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, outType, 0);
+  rewriter.replaceOpWithNewOp<tosa::MulOp>(op, outType, logOp, rcpOp, mulShift);
 
   return success();
 }
@@ -3497,26 +3507,27 @@ approximateErfOp(ConversionPatternRewriter &rewriter, Operation *op, Value x,
       mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a4).failed())
     return std::nullopt;
 
-  auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, outType, 0);
+  auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, mulShift);
   auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
 
-  auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
-  auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
+  auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, mulShift);
+  auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, mulShift);
   sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
 
-  auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
-  auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
+  auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, mulShift);
+  auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, mulShift);
   sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
 
-  auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
-  auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
+  auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, mulShift);
+  auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, mulShift);
   sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
 
   auto rcprl = rewriter.create<tosa::ReciprocalOp>(loc, outType, sum);
   auto rcprl2 =
-      rewriter.create<tosa::MulOp>(loc, outType, rcprl, rcprl, /*shift=*/0);
+      rewriter.create<tosa::MulOp>(loc, outType, rcprl, rcprl, mulShift);
   auto rcprl4 =
-      rewriter.create<tosa::MulOp>(loc, outType, rcprl2, rcprl2, /*shift=*/0);
+      rewriter.create<tosa::MulOp>(loc, outType, rcprl2, rcprl2, mulShift);
   auto erf = rewriter.create<tosa::SubOp>(loc, outType, one, rcprl4);
 
   // Deal with negative x.
@@ -3553,13 +3564,14 @@ buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x,
   auto mean = zero;
   Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
 
-  Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
-                                              /*shift=*/0);
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, outType, 0);
+  Value erfArg =
+      rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2, mulShift);
   Value erf = approximateErfOp(rewriter, op, erfArg, dtype).value();
   Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
 
-  Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
-                                                 erfPlus1, /*shift=*/0);
+  Value normalCdf =
+      rewriter.create<tosa::MulOp>(loc, outType, oneHalf, erfPlus1, mulShift);
   return normalCdf;
 }
@@ -3599,8 +3611,10 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
         op->getLoc(),
         cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);
 
+    auto mulShift =
+        tosa::getTosaMulShiftConstTensor(rewriter, op, resultType, 0);
     rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, cdf,
-                                             /*shift=*/0);
+                                             mulShift);
   } else if (approximate.compare("tanh") == 0) {
     // "tanh" approximate
     // GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
@@ -3644,8 +3658,10 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
             .value();
 
     // 0.5 * x
+    auto mulShift =
+        tosa::getTosaMulShiftConstTensor(rewriter, op, resultType, 0);
     auto halfInput = rewriter.create<tosa::MulOp>(op->getLoc(), resultType,
-                                                  half, self, /*shift=*/0);
+                                                  half, self, mulShift);
 
     // sqrt(2/pi)
     auto sqrtTwoOverPi =
@@ -3658,7 +3674,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
     // 0.044715 * x^3
     auto inputPowThreeMul =
         rewriter.create<tosa::MulOp>(op->getLoc(), resultType, magicNumber,
-                                     inputPowThree.getResult(), /*shift=*/0);
+                                     inputPowThree.getResult(), mulShift);
 
     // x + 0.044715 * x^3
     auto inputPowThreeMulAdd = rewriter.create<tosa::AddOp>(
@@ -3667,7 +3683,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
     // sqrt(2/pi) * (x + 0.044715 * x^3)
     auto sqrtTwoOverPiMul = rewriter.create<tosa::MulOp>(
         op->getLoc(), resultType, sqrtTwoOverPi.getResult(),
-        inputPowThreeMulAdd.getResult(), /*shift=*/0);
+        inputPowThreeMulAdd.getResult(), mulShift);
 
     // tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
     auto tanh = rewriter.create<tosa::TanhOp>(op->getLoc(), resultType,
@@ -3678,8 +3694,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
                                               tanh.getResult());
 
     rewriter.replaceOpWithNewOp<tosa::MulOp>(
-        op, resultType, halfInput.getResult(), tanhAdd.getResult(),
-        /*shift=*/0);
+        op, resultType, halfInput.getResult(), tanhAdd.getResult(), mulShift);
   } else {
     return rewriter.notifyMatchFailure(op,
                                        "Unsupported approximation algorithm");
@@ -3732,23 +3747,23 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Failed to equalize ranks among operands and result");
 
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, selfType, 0);
   Value inputSquared =
-      rewriter.create<tosa::MulOp>(loc, selfType, self, self, /*shift=*/0);
+      rewriter.create<tosa::MulOp>(loc, selfType, self, self, mulShift);
   Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
-      loc, selfType, inputSquared, negOneHalf, /*shift=*/0);
+      loc, selfType, inputSquared, negOneHalf, mulShift);
   Value dinput =
       rewriter.create<tosa::ExpOp>(loc, selfType, negHalfInputSquared);
   Value cdf = buildUnitNormalCdf(rewriter, op, self, selfElemTy).value();
   Value dinputInput =
-      rewriter.create<tosa::MulOp>(loc, selfType, dinput, self, /*shift=*/0);
+      rewriter.create<tosa::MulOp>(loc, selfType, dinput, self, mulShift);
   Value dinputInputAlpha = rewriter.create<tosa::MulOp>(
-      loc, selfType, dinputInput, kAlphaHalf, /*shift=*/0);
+      loc, selfType, dinputInput, kAlphaHalf, mulShift);
   Value cdfExt =
       rewriter.create<tosa::AddOp>(loc, selfType, dinputInputAlpha, cdf);
 
   rewriter.replaceOpWithNewOp<tosa::MulOp>(
       op, getTypeConverter()->convertType(op.getType()),
-      adaptor.getGradOutput(), cdfExt,
-      /*shift=*/0);
+      adaptor.getGradOutput(), cdfExt, mulShift);
 
   return success();
 }
@@ -5232,8 +5247,9 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
       rewriter.create<tosa::AbsOp>(op->getLoc(), selfType, rhsSubOp);
   auto lhsAbsOp =
       rewriter.create<tosa::AbsOp>(op->getLoc(), otherType, other);
 
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, otherType, 0);
   auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), otherType,
-                                            rtolConstOp, lhsAbsOp, /*shift=*/0);
+                                            rtolConstOp, lhsAbsOp, mulShift);
   auto addOp =
       rewriter.create<tosa::AddOp>(op->getLoc(), otherType, atolConstOp, mulOp);
@@ -5778,8 +5794,10 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
     if (isa<mlir::FloatType>(outElemTy)) {
       auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
           op.getLoc(), otherTensor.getType(), otherTensor);
+      auto mulShift =
+          tosa::getTosaMulShiftConstTensor(rewriter, op, outType, 0);
       divTensor = rewriter.create<tosa::MulOp>(
-          op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
+          op.getLoc(), outType, self, otherTensorReciprocal, mulShift);
       divTensor =
           rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
     } else {
@@ -5804,9 +5822,9 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
       }
     }
 
-    auto mulTensor = rewriter.create<tosa::MulOp>(op.getLoc(), outType,
-                                                  otherTensor, divTensor,
-                                                  /*shift=*/0);
+    auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, outType, 0);
+    auto mulTensor = rewriter.create<tosa::MulOp>(
+        op.getLoc(), outType, otherTensor, divTensor, mulShift);
     rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);
 
     return success();
@@ -7010,8 +7028,9 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Failed to equalize ranks among operands and result");
 
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, resultType, 0);
   rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
-                                           /*shift=*/0);
+                                           mulShift);
 
   return success();
 }
@@ -7106,15 +7125,16 @@ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
   auto ceilInput = rewriter.create<tosa::CeilOp>(op->getLoc(), resultTy, self);
 
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, resultTy, 0);
   auto floorInputDivByTwo = rewriter.create<tosa::MulOp>(
-      op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
+      op->getLoc(), resultTy, floorInput.getResult(), oneHalf, mulShift);
 
   auto floorDivResult = rewriter.create<tosa::FloorOp>(
       op->getLoc(), resultTy, floorInputDivByTwo.getResult());
 
   // (floor(input) // 2) * 2
   auto evenComparison = rewriter.create<tosa::MulOp>(
-      op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0);
+      op->getLoc(), resultTy, floorDivResult.getResult(), two, mulShift);
 
   // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
   auto floorInputEven = rewriter.create<tosa::EqualOp>(
@@ -7296,9 +7316,11 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Failed to equalize ranks among operands and result");
 
-  Value diagonalTensor = rewriter.create<tosa::MulOp>(
-      op->getLoc(), transposedInputType, selfTransposed, diagonalMask,
-      /*shift=*/0);
+  auto mulShift =
+      tosa::getTosaMulShiftConstTensor(rewriter, op, transposedInputType, 0);
+  Value diagonalTensor =
+      rewriter.create<tosa::MulOp>(op->getLoc(), transposedInputType,
+                                   selfTransposed, diagonalMask, mulShift);
 
   auto resultShape = makeShapeTorchCompatible(resultType.getShape());
   auto targetReduceDim = resultShape[resultType.getRank() - 1];
@@ -8587,9 +8609,9 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
   auto oneMinusZiReciprocal = rewriter.create<tosa::ReciprocalOp>(
       op->getLoc(), resultType, oneMinusZi.getResult());
 
-  auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), resultType, zi,
-                                            oneMinusZiReciprocal.getResult(),
-                                            /*shift=*/0);
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, resultType, 0);
+  auto mulOp = rewriter.create<tosa::MulOp>(
+      op->getLoc(), resultType, zi, oneMinusZiReciprocal.getResult(), mulShift);
 
   auto result =
       rewriter.create<tosa::LogOp>(op->getLoc(), resultType, mulOp.getResult());
@@ -8687,9 +8709,10 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
   auto reciprocalOp = rewriter.create<tosa::ReciprocalOp>(
       op->getLoc(), constTenType, logOfTen.getResult());
 
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, resultType, 0);
   auto result = rewriter.create<tosa::MulOp>(
       op->getLoc(), resultType, logOfSelf.getResult(), reciprocalOp.getResult(),
-      /*shift=*/0);
+      mulShift);
 
   rewriter.replaceOp(op, {result.getResult()});
@@ -8772,9 +8795,10 @@ LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
   auto reciprocalOp =
       rewriter.create<tosa::ReciprocalOp>(op->getLoc(), resultType, cosOp);
 
-  auto result = rewriter.create<tosa::MulOp>(
-      op->getLoc(), resultType, sinOp.getResult(), reciprocalOp.getResult(),
-      /*shift=*/0);
+  auto mulShift = tosa::getTosaMulShiftConstTensor(rewriter, op, resultType, 0);
+  auto result =
+      rewriter.create<tosa::MulOp>(op->getLoc(), resultType, sinOp.getResult(),
+                                   reciprocalOp.getResult(), mulShift);
 
   rewriter.replaceOp(op, {result.getResult()});
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
index 1f18cabd8cb0..9e769bbc17c6 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -119,8 +119,10 @@ tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
                                int32_t shift) {
   lhs = promoteType(rewriter, lhs, outType);
   rhs = promoteType(rewriter, rhs, outType);
+  auto constShift =
+      tosa::getTosaMulShiftConstTensor(rewriter, op, outType, shift);
   return tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), outType,
-                                             lhs, rhs, shift);
+                                             lhs, rhs, constShift);
 }
 
 template <>
@@ -386,10 +388,13 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
   // Multiply the coefficients by the coordinates
   // %5 = "tosa.mul"(%3, %4) {shift = 0 : i32} : (tensor<8x3xi32>,
   // tensor<3xi32>) -> tensor<8x3xi32>
+  auto flattenedMulType =
+      GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType());
+  auto mulShift =
+      tosa::getTosaMulShiftConstTensor(rewriter, op, flattenedMulType, 0);
   auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
-      rewriter, op->getLoc(),
-      GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
-      indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0);
+      rewriter, op->getLoc(), flattenedMulType, indicesMatrixReshapeOp,
+      flattenedCoeffValue.value(), mulShift);
 
   // Sum up the products of the coefficients and coordinates
   // %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) ->
@@ -657,10 +662,13 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
   // [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]]
   // %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>,
   // tensor<2xi32>) -> tensor<3x2xi32>
+  auto flattenedMulType =
+      GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType());
+  auto mulShift =
+      tosa::getTosaMulShiftConstTensor(rewriter, op, flattenedMulType, 0);
   auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
-      rewriter, op->getLoc(),
-      GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
-      indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0);
+      rewriter, op->getLoc(), flattenedMulType, indicesMatrixReshapeOp,
+      flattenedCoeffValue.value(), mulShift);
 
   // Sum up the products of the coefficients and coordinates
   // [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]]
@@ -1006,8 +1014,10 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
           .failed())
     return std::nullopt;
 
+  auto mulShift =
+      tosa::getTosaMulShiftConstTensor(rewriter, op, output_type, 0);
   return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
-                                       val.value(), div_const, 0)
+                                       val.value(), div_const, mulShift)
       .getResult();
 }
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index 2243c8dcfd83..38de55492e71 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -163,6 +163,18 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
   return const_op.getResult();
 }
 
+Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
+                                 Type /*resultType*/, int32_t shift) {
+  auto shiftType = RankedTensorType::get({1}, rewriter.getI8Type());
+  auto shiftAttr = DenseElementsAttr::get(
+      shiftType, llvm::ArrayRef<int8_t>{static_cast<int8_t>(shift)});
+
+  auto constShift =
+      rewriter.create<tosa::ConstOp>(op->getLoc(), shiftType, shiftAttr);
+
+  return constShift.getResult();
+}
+
 // Create a zero constant tensor of the desired type and shape.
 std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
                                         Operation *op, Type type) {
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 941d710e4f2e..4b8bbb609520 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -262,7 +262,7 @@ func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
 // CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_6]] : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xi1>
-// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], {{.*}}, {{.*}}: (tensor<?x?xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_9:.*]] = tosa.select %[[VAL_7]], %[[VAL_1]], %[[VAL_8]] : (tensor<?x?xi1>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?],f32>
@@ -377,7 +377,7 @@ func.func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 1
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
-// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], {{.*}}, {{.*}}: (tensor<?x?xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_3]], %[[VAL_7]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],f32>
@@ -398,7 +398,7 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 1
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
-// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], {{.*}}, {{.*}}: (tensor<?x?xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_3]], %[[VAL_7]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],f32>
@@ -416,7 +416,7 @@ func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
 // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], {{.*}}, {{.*}}: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<1xi8>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -433,7 +433,7 @@ func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
 // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], {{.*}}, {{.*}}: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<1xi8>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -469,7 +469,7 @@ func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
 // CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<-1.08420217E-19> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 1, 1>} : (tensor<f32>) -> tensor<1x1x1xf32>
-// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_7]], %[[VAL_9]] {shift = 0 : i8} : (tensor<?x?x?xf32>, tensor<1x1x1xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_7]], {{.*}}, {{.*}}: (tensor<?x?x?xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<?x?x?xf32>
 // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?],f32>
 // CHECK: }
@@ -682,7 +682,7 @@ func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>)
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<6.432100e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
 // CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
-// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_7]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], {{.*}}, {{.*}}: (tensor<?x?xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_6]], %[[VAL_8]] : (tensor<1x1xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?],f32>
@@ -705,7 +705,7 @@ func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
 // CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
-// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_7]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], {{.*}}, {{.*}}: (tensor<?x?xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_6]], %[[VAL_8]] : (tensor<1x1xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?],f32>
@@ -807,8 +807,8 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to
 // CHECK: %[[VAL_18:.*]] = tosa.sub %[[VAL_1]], %[[VAL_13]] : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32>
 // CHECK: %[[VAL_19:.*]] = tosa.add %[[VAL_14]], %[[VAL_15]] : (tensor<1x4x1xf32>, tensor<1x1x1xf32>) -> tensor<1x4x1xf32>
 // CHECK: %[[VAL_20:.*]] = tosa.rsqrt %[[VAL_19]] : (tensor<1x4x1xf32>) -> tensor<1x4x1xf32>
-// CHECK: %[[VAL_21:.*]] = tosa.mul %[[VAL_18]], %[[VAL_20]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32>
-// CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_21]], %[[VAL_16]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32>
+// CHECK: %[[VAL_21:.*]] = tosa.mul %[[VAL_18]], {{.*}}, {{.*}}: (tensor<10x4x3xf32>, tensor<1x4x1xf32>, tensor<1xi8>) -> tensor<10x4x3xf32>
+// CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_21]], {{.*}}, {{.*}}: (tensor<10x4x3xf32>, tensor<1x4x1xf32>, tensor<1xi8>) -> tensor<10x4x3xf32>
 // CHECK: %[[VAL_23:.*]] = tosa.add %[[VAL_22]], %[[VAL_17]] : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32>
 // CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32>
 // CHECK: return %[[VAL_24]] : !torch.vtensor<[10,4,3],f32>
@@ -885,14 +885,14 @@ func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3
 // CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32>
 // CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32>
 // CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 5, 1, 1, 1>} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
-// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], {{.*}}, {{.*}}: (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<5x1x1x1xf32>
 // CHECK: %[[VAL_18:.*]] = tosa.sub %[[VAL_5]], %[[VAL_17]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
-// CHECK: %[[VAL_19:.*]] = tosa.mul %[[VAL_18]], %[[VAL_18]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.mul %[[VAL_18]], {{.*}}, {{.*}}: (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>, tensor<1xi8>) -> tensor<5x2x2x3xf32>
 // CHECK: %[[VAL_20:.*]] = tosa.reduce_sum %[[VAL_19]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
 // CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32>
 // CHECK: %[[VAL_22:.*]] = tosa.reduce_sum %[[VAL_21]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32>
 // CHECK: %[[VAL_23:.*]] = tosa.reshape %[[VAL_22]] {new_shape = array<i64: 5, 1, 1, 1>} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
-// CHECK: %[[VAL_24:.*]] = tosa.mul %[[VAL_23]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32>
+// CHECK: %[[VAL_24:.*]] = tosa.mul %[[VAL_23]], {{.*}}, {{.*}}: (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<5x1x1x1xf32>
 // CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1, 2, 2, 3>} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
 // CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 2, 2, 3>} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
 // CHECK: %[[VAL_27:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
@@ -900,8 +900,8 @@ func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3
 // CHECK: %[[VAL_29:.*]] = tosa.sub %[[VAL_5]], %[[VAL_17]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
 // CHECK: %[[VAL_30:.*]] = tosa.add %[[VAL_24]], %[[VAL_28]] : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32>
 // CHECK: %[[VAL_31:.*]] = tosa.rsqrt %[[VAL_30]] : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
-// CHECK: %[[VAL_32:.*]] = tosa.mul %[[VAL_29]], %[[VAL_31]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
-// CHECK: %[[VAL_33:.*]] = tosa.mul %[[VAL_32]], %[[VAL_25]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32>
+// CHECK: %[[VAL_32:.*]] = tosa.mul %[[VAL_29]], {{.*}}, {{.*}}: (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>, tensor<1xi8>) -> tensor<5x2x2x3xf32>
+// CHECK: %[[VAL_33:.*]] = tosa.mul %[[VAL_32]], {{.*}}, {{.*}}: (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>, tensor<1xi8>) -> tensor<5x2x2x3xf32>
 // CHECK: %[[VAL_34:.*]] = tosa.add %[[VAL_33]], %[[VAL_26]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32>
 // CHECK: %[[VAL_35:.*]] = torch_c.from_builtin_tensor %[[VAL_34]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32>
 // CHECK: return %[[VAL_35]] : !torch.vtensor<[5,2,2,3],f32>
@@ -995,7 +995,7 @@ func.func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>
 // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.693147182> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
 // CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<1x1xf32>) -> tensor<1x1xf32>
 // CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_1]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_4]], %[[VAL_3]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_4]], {{.*}}, {{.*}}: (tensor<?x?xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -1325,7 +1325,7 @@ func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !
 // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array<i64: 8, 3>} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32>
 // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
 // CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array<i64: 1, 3>} : (tensor<3xi32>) -> tensor<1x3xi32>
-// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] {shift = 0 : i8} : (tensor<8x3xi32>, tensor<1x3xi32>) -> tensor<8x3xi32>
+// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]], %{{.*}} : (tensor<8x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<8x3xi32>
 // CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32>
 // CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 8>} : (tensor<8x1xi32>) -> tensor<1x8xi32>
 // CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_11]], %[[VAL_17]] : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32>
@@ -1349,7 +1349,7 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 1
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1>} : (tensor<i32>) -> tensor<1x1xi32>
-// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2x2xi32>, tensor<1x1xi32>) -> tensor<2x2xi32>
+// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]], %{{.*}} : (tensor<2x2xi32>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<2x2xi32>
 // CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_3]], %[[VAL_7]] : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 // CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor<2x2xi32>) -> tensor<2x2xi64>
 // CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64>
@@ -1372,7 +1372,7 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torc
 // CHECK: %[[VAL_5_cast:.*]] = tosa.cast %[[VAL_5]] : (tensor<1x1x1x1xi64>) -> tensor<1x1x1x1xi32>
 // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array<i64: 1, 1, 1, 1>} : (tensor<i32>) -> tensor<1x1x1x1xi32>
-// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5_cast]], %[[VAL_7]] {shift = 0 : i8} : (tensor<1x1x1x1xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x1x1xi32>
+// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5_cast]], %[[VAL_7]], %{{.*}} : (tensor<1x1x1x1xi32>, tensor<1x1x1x1xi32>, tensor<1xi8>) -> tensor<1x1x1x1xi32>
 // CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32>
 // CHECK: %[[VAL_10:.*]] = tosa.add %[[VAL_9]], %[[VAL_8]] : (tensor<1x1x128x128xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x128x128xi32>
 // CHECK: %[[VAL_11:.*]] = tosa.cast %[[VAL_10]] : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64>
@@ -1554,9 +1554,9 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to
 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
 // CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<1x1xf32>) -> tensor<1x1xf32>
-// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<1x1xf32>) -> tensor<2x4xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], {{.*}}, {{.*}}: (tensor<2x4xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x4xf32>
 // CHECK: %[[VAL_7:.*]] = tosa.floor %[[VAL_6]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
-// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_4]], %[[VAL_7]] {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_4]], {{.*}}, {{.*}}: (tensor<1x1xf32>, tensor<2x4xf32>, tensor<1xi8>) -> tensor<2x4xf32>
 // CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
 // CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
 // CHECK: return %[[VAL_10]] : !torch.vtensor<[2,4],f32>
@@ -1584,7 +1584,7 @@ func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !to
 // CHECK: %[[VAL_11:.*]] = tosa.sub %[[VAL_3]], %[[VAL_2]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32>
 // CHECK: %[[VAL_12:.*]] = tosa.abs %[[VAL_11]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
 // CHECK: %[[VAL_13:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
-// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_9]], %[[VAL_13]] {shift = 0 : i8} : (tensor<1x1xf32>, tensor<5x5xf32>) -> tensor<5x5xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_9]], {{.*}}, {{.*}}: (tensor<1x1xf32>, tensor<5x5xf32>, tensor<1xi8>) -> tensor<5x5xf32>
 // CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_10]], %[[VAL_14]] : (tensor<1x1xf32>, tensor<5x5xf32>) -> tensor<5x5xf32>
 // CHECK: %[[VAL_16:.*]] = tosa.greater_equal %[[VAL_15]], %[[VAL_12]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1>
 // CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1>
@@ -1687,7 +1687,7 @@ func.func @torch.aten.__interpolate.size_list_scale_list.nearest(%arg0: !torch.v
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],si32> -> tensor<2x4xi32>
 // CHECK: %[[VAL_2:.*]] = torch.constant.int 1
 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 1, 0, 0], [1, 1, 1, 0]]> : tensor<2x4xi32>}> : () -> tensor<2x4xi32>
-// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
+// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]], %{{.*}} : (tensor<2x4xi32>, tensor<2x4xi32>, tensor<1xi8>) -> tensor<2x4xi32>
 // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x4xi32> -> !torch.vtensor<[2,4],si32>
 // CHECK: return %[[VAL_5]] : !torch.vtensor<[2,4],si32>
 // CHECK: }
@@ -1803,7 +1803,7 @@ func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_4:.*]] = torch.constant.str "trunc"
 // CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], {{.*}}, {{.*}}: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<1xi8>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
@@ -1814,7 +1814,7 @@ func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch
 // CHECK: %[[VAL_14:.*]] = tosa.select %[[VAL_13]], %[[VAL_10]], %[[VAL_12]] : (tensor<?x?xi1>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_15:.*]] = tosa.abs %[[VAL_6]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_16:.*]] = tosa.floor %[[VAL_15]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_14]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], {{.*}}, {{.*}}: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<1xi8>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_18]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -1854,7 +1854,7 @@ func.func @torch.aten.div.Tensor_mode$int_trunc(%arg0: !torch.vtensor<[?, ?],si6
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_4:.*]] = torch.constant.str "floor"
 // CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], {{.*}}, {{.*}}: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<1xi8>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_7:.*]] = tosa.floor %[[VAL_6]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -1880,9 +1880,9 @@ func.func @torch.aten.div.Tensor_mode$float_floor(%arg0: !torch.vtensor<[?, ?],f
 // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<1> : tensor<i32>}> : () -> tensor<i32>
 // CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 1, 1>} : (tensor<i32>) -> tensor<1x1xi32>
 // CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 1>} : (tensor<i32>) -> tensor<1x1xi32>
-// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_5]], %[[VAL_6]] {shift = 0 : i8} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_5]], %[[VAL_6]], %{{.*}} : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<1xi8>) -> tensor<?x?xi32>
 // CHECK: %[[VAL_13:.*]] = tosa.greater %[[VAL_11]], %[[VAL_12]] : (tensor<1x1xi32>, tensor<?x?xi32>) -> tensor<?x?xi1>
-// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_7]], %[[VAL_6]] {shift = 0 : i8} : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_7]], %[[VAL_6]], %{{.*}} : (tensor<?x?xi32>, tensor<?x?xi32>, tensor<1xi8>) -> tensor<?x?xi32>
 // CHECK: %[[VAL_15:.*]] = tosa.equal %[[VAL_14]], %[[VAL_5]] : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi1>
 // CHECK: %[[VAL_16:.*]] = tosa.logical_not %[[VAL_15]] : (tensor<?x?xi1>) -> tensor<?x?xi1>
 // CHECK: %[[VAL_17:.*]] = tosa.sub %[[VAL_7]], %[[VAL_10]] : (tensor<?x?xi32>, tensor<1x1xi32>) -> tensor<?x?xi32>
@@ -1907,7 +1907,7 @@ func.func @torch.aten.div.Tensor_mode$int_floor(%arg0: !torch.vtensor<[?, ?],si6
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[VAL_4:.*]] = torch.constant.str ""
 // CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], {{.*}}, {{.*}}: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<1xi8>) -> tensor<?x?xf32>
 // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
 // CHECK: }
@@ -1962,9 +1962,9 @@ func.func @torch.aten.ge.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
 // CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
-// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], {{.*}}, {{.*}}: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<1xi8>) -> tensor<2x4xf32>
 // CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
-// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], {{.*}}, {{.*}}: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<1xi8>) -> tensor<2x4xf32>
 // CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_3]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
 // CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
 // CHECK: return %[[VAL_9]] : !torch.vtensor<[2,4],f32>
@@ -1982,7 +1982,7 @@ func.func @torch.aten.remainder.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1:
 // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
 // CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
-// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], {{.*}}, {{.*}}: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<1xi8>) -> tensor<2x4xf32>
 // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
 // CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
@@ -1993,8 +1993,8 @@ func.func @torch.aten.remainder.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1:
 // CHECK: %[[VAL_13:.*]] = tosa.select %[[VAL_12]], %[[VAL_9]], %[[VAL_11]] : (tensor<2x4xi1>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<2x4xf32>
 // CHECK: %[[VAL_14:.*]] = tosa.abs %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
 // CHECK: %[[VAL_15:.*]] = tosa.floor %[[VAL_14]] : (tensor<2x4xf32>) -> tensor<2x4xf32>
-// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_13]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
-// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_2]], %[[VAL_16]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], {{.*}}, {{.*}}: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<1xi8>) -> tensor<2x4xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_2]], {{.*}}, {{.*}}: (tensor<2x4xf32>, tensor<2x4xf32>, tensor<1xi8>) -> tensor<2x4xf32>
 // CHECK: %[[VAL_18:.*]] = tosa.sub %[[VAL_3]], %[[VAL_17]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
 // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
 // CHECK: return %[[VAL_19]] : !torch.vtensor<[2,4],f32>
@@ -2207,7 +2207,7 @@ func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
 // CHECK: %[[VAL_6:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_5]] : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<5x6x4x3xi32>
 // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32>
-// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]] {shift = 0 : i8} : (tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>) -> tensor<5x6x4x3xi32>
+// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]], %{{.*}} : (tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>, tensor<1xi8>) -> tensor<5x6x4x3xi32>
 // CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array<i64: 5, 6, 2, 3>, start = array} : (tensor<5x6x4x3xi32>) -> tensor<5x6x2x3xi32>
 // CHECK: %[[VAL_10:.*]] = tosa.reduce_sum %[[VAL_9]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32>
 // CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array<i64: 5, 6, 2>} : (tensor<5x6x2x1xi32>) -> tensor<5x6x2xi32>
@@ -2242,7 +2242,7 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) ->
 // CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
 // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
 // CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 1, 3>} : (tensor<3xi32>) -> tensor<1x3xi32>
-// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<1x3xi32>) -> tensor<40x3xi32>
+// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]], %{{.*}} : (tensor<40x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<40x3xi32>
 // CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
 // CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
 // CHECK: %[[VAL_20:.*]] = tosa.gather %[[VAL_13]], %[[VAL_19]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
@@ -2323,9 +2323,9 @@ func.func @torch.aten.flip(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor
 // CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
 // CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
 // CHECK: %[[VAL_8:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
-// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_6]], %[[VAL_4]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_6]], {{.*}}, {{.*}}: (tensor<3x4x5xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<3x4x5xf32>
 // CHECK: %[[VAL_10:.*]] = tosa.floor %[[VAL_9]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
-// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_5]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xf32>
+// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], {{.*}}, {{.*}}: (tensor<3x4x5xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<3x4x5xf32>
 // CHECK: %[[VAL_12:.*]] = tosa.equal %[[VAL_6]], %[[VAL_11]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
 // CHECK: %[[VAL_13:.*]] = tosa.equal %[[VAL_7]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xi1>
 // CHECK: %[[VAL_14:.*]] = tosa.greater %[[VAL_4]], %[[VAL_7]] : (tensor<1x1x1xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
@@ -2426,7 +2426,7 @@ func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64>
 // CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 24, 3>} : (tensor<2x4x3x3xi32>) -> tensor<24x3xi32>
 // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[48, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
 // CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 1, 3>} : (tensor<3xi32>) -> tensor<1x3xi32>
-// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]] {shift = 0 : i8} : (tensor<24x3xi32>, tensor<1x3xi32>) -> tensor<24x3xi32>
+// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]], %{{.*}} : (tensor<24x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<24x3xi32>
 // CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<24x3xi32>) -> tensor<24x1xi32>
 // CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 1, 24>} : (tensor<24x1xi32>) -> tensor<1x24xi32>
 // CHECK: %[[VAL_20:.*]] = tosa.scatter %[[VAL_13]], %[[VAL_19]], %[[VAL_12]] : (tensor<1x480x1xf32>, tensor<1x24xi32>, tensor<1x36x1xf32>) -> tensor<1x480x1xf32>
@@ -2458,7 +2458,7 @@ func.func @torch.aten.scatter.src$basic(%arg0: !torch.vtensor<[10,8,6],f32>, %ar
 // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 6, 2>} : (tensor<6x1x2xi32>) -> tensor<6x2xi32>
 // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[8, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
 // CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array<i64: 1, 2>} : (tensor<2xi32>) -> tensor<1x2xi32>
-// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] {shift = 0 : i8} : (tensor<6x2xi32>, tensor<1x2xi32>) -> tensor<6x2xi32>
+// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]], %{{.*}} : (tensor<6x2xi32>, tensor<1x2xi32>, tensor<1xi8>) -> tensor<6x2xi32>
 // CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<6x2xi32>) -> tensor<6x1xi32>
 // CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 6>} : (tensor<6x1xi32>) -> tensor<1x6xi32>
 // CHECK: %[[VAL_18:.*]] = tosa.scatter %[[VAL_11]], %[[VAL_17]], %[[VAL_10]] : (tensor<1x48x1xf32>, tensor<1x6xi32>, tensor<1x6x1xf32>) -> tensor<1x48x1xf32>
@@ -2494,7 +2494,7 @@ func.func @torch.aten.slice_scatter$basic(%arg0: !torch.vtensor<[6,8],f32>, %arg
 // CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 24, 4>} : (tensor<2x3x4x1x4xi32>) -> tensor<24x4xi32>
 // CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[48, 16, 4, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
 // CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 4>} : (tensor<4xi32>) -> tensor<1x4xi32>
-// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]] {shift = 0 : i8} : (tensor<24x4xi32>, tensor<1x4xi32>) -> tensor<24x4xi32>
+// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]], %{{.*}} : (tensor<24x4xi32>, tensor<1x4xi32>, tensor<1xi8>) -> tensor<24x4xi32>
 // CHECK: %[[VAL_19:.*]] = tosa.reduce_sum %[[VAL_18]] {axis = 1 : i32} : (tensor<24x4xi32>) -> tensor<24x1xi32>
 // CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 1, 24>} : (tensor<24x1xi32>) -> tensor<1x24xi32>
 // CHECK: %[[VAL_21:.*]] = tosa.scatter %[[VAL_14]], %[[VAL_20]], %[[VAL_13]] : (tensor<1x96x1xf32>, tensor<1x24xi32>, tensor<1x24x1xf32>) -> tensor<1x96x1xf32>
@@ -2531,7 +2531,7 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t
 // CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
 // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
-// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1x1xi32>) -> tensor<1x1xi32>
+// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]], %{{.*}} : (tensor<1x1xi32>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<1x1xi32>
 // CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32>
 // CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
 // CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_12]], %[[VAL_18]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64>
@@ -2635,7 +2635,7 @@ func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch
 // CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 9, 1>} : (tensor<9x1xi32>) -> tensor<9x1xi32>
 // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
 // CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
-// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1x1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]], %{{.*}} : (tensor<9x1xi32>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<9x1xi32>
 // CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
 // CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array<i64: 1, 9>} : (tensor<9x1xi32>) -> tensor<1x9xi32>
 // CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_12]], %[[VAL_18]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32>
@@ -2919,7 +2919,7 @@ func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f3
 // CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 4>} : (tensor<4xf32>) -> tensor<1x4xf32>
 // CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[3, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
 // CHECK: %[[VAL_9:.*]] = tosa.tile %[[VAL_7]], %[[VAL_8]] : (tensor<1x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_6]], %[[VAL_9]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_6]], {{.*}}, {{.*}}: (tensor<3x4xf32>, tensor<3x4xf32>, tensor<1xi8>) -> tensor<3x4xf32>
 // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
 // CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],f32>
 // CHECK: }
@@ -2968,7 +2968,7 @@ func.func @torch.prims.split_dim$basic(%arg0: !torch.vtensor<[1,8,3,3],si64>) ->
 // CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 72, 3>} : (tensor<1x1x72x3xi32>) -> tensor<72x3xi32>
 // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[6, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
 // CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 1, 3>} : (tensor<3xi32>) -> tensor<1x3xi32>
-// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]] {shift = 0 : i8} : (tensor<72x3xi32>, tensor<1x3xi32>) -> tensor<72x3xi32>
+// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]], %{{.*}} : (tensor<72x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<72x3xi32>
 // CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<72x3xi32>) -> tensor<72x1xi32>
 // CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 1, 72>} : (tensor<72x1xi32>) -> tensor<1x72xi32>
 // CHECK: %[[VAL_20:.*]] = tosa.gather %[[VAL_13]], %[[VAL_19]] : (tensor<1x6x1xf64>, tensor<1x72xi32>) -> tensor<1x72x1xf64>
@@ -3006,7 +3006,7 @@ func.func @torch.aten.upsample_nearest2d$basic(%arg0: !torch.vtensor<[1,1,2,3],f
 // CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 14, 3>} : (tensor<1x1x14x3xi32>) -> tensor<14x3xi32>
 // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[20, 20, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
 // CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 1, 3>} : (tensor<3xi32>) -> tensor<1x3xi32>
-// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<14x3xi32>, tensor<1x3xi32>) -> tensor<14x3xi32>
+// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]], %{{.*}} : (tensor<14x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<14x3xi32>
 // CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<14x3xi32>) -> tensor<14x1xi32>
 // CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array<i64: 1, 14>} : (tensor<14x1xi32>) -> tensor<1x14xi32>
 // CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_12]], %[[VAL_18]] : (tensor<1x20x1xf32>, tensor<1x14xi32>) -> tensor<1x14x1xf32>
@@ -3035,15 +3035,15 @@ func.func @torch.aten.upsample_nearest2d.vec$basic(%arg0: !torch.vtensor<[1,1,4,
 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<5x3xf32>}> : () -> tensor<5x3xf32>
 // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<4.471500e-02> : tensor<5x3xf32>}> : () -> tensor<5x3xf32>
 // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.636619746> : tensor<5x3xf32>}> : () -> tensor<5x3xf32>
-// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_3]], %[[VAL_1]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_3]], {{.*}}, {{.*}}: (tensor<5x3xf32>, tensor<5x3xf32>, tensor<1xi8>) -> tensor<5x3xf32>
 // CHECK: %[[VAL_9:.*]] = tosa.pow %[[VAL_7]], %[[VAL_3]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32>
 // CHECK: %[[VAL_10:.*]] = tosa.pow %[[VAL_1]], %[[VAL_5]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32>
-// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_6]], %[[VAL_10]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32>
+// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_6]], {{.*}}, {{.*}}: (tensor<5x3xf32>, tensor<5x3xf32>, tensor<1xi8>) -> tensor<5x3xf32>
 // CHECK: %[[VAL_12:.*]] = tosa.add %[[VAL_1]], %[[VAL_11]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32>
-// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_9]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_9]], {{.*}}, {{.*}}: (tensor<5x3xf32>, tensor<5x3xf32>, tensor<1xi8>) -> tensor<5x3xf32>
 // CHECK: %[[VAL_14:.*]] = tosa.tanh %[[VAL_13]] : (tensor<5x3xf32>) -> tensor<5x3xf32>
 // CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_4]], %[[VAL_14]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32>
-// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_8]], %[[VAL_15]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32>
+// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_8]], {{.*}}, {{.*}}: (tensor<5x3xf32>, tensor<5x3xf32>, tensor<1xi8>) -> tensor<5x3xf32>
 // CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<5x3xf32> -> !torch.vtensor<[5,3],f32>
 // CHECK: return %[[VAL_17]] : !torch.vtensor<[5,3],f32>
 // CHECK: }
@@ -3078,7 +3078,7 @@ func.func @torch.aten.exp$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtens
 // CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
 // CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_3]] : (tensor<1x1xf32>) -> tensor<1x1xf32>
 // CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5]] : (tensor<1x1xf32>) -> tensor<1x1xf32>
-// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_6]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], {{.*}}, {{.*}}: (tensor<3x4xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<3x4xf32>
 // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
 // CHECK: return %[[VAL_8]] : !torch.vtensor<[3,4],f32>
 // CHECK: }
@@ -3098,7 +3098,7 @@ func.func @torch.aten.log10$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vt
 // CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
 // CHECK: %[[VAL_6:.*]] = tosa.log %[[VAL_4]] : (tensor<1x1xf32>) -> tensor<1x1xf32>
 // CHECK: %[[VAL_7:.*]] = tosa.reciprocal %[[VAL_6]] : (tensor<1x1xf32>) -> tensor<1x1xf32>
-// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], {{.*}}, {{.*}}: (tensor<3x4xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<3x4xf32>
 // CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
 // CHECK: return %[[VAL_9]] : !torch.vtensor<[3,4],f32>
 // CHECK: }
@@ -3153,7 +3153,7 @@ func.func @torch.aten.log1p$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vte
 // CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
 // CHECK: %[[VAL_6:.*]] = tosa.sub %[[VAL_5]], %[[VAL_3]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
 // CHECK: %[[VAL_7:.*]] = tosa.reciprocal %[[VAL_6]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_3]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_3]], {{.*}}, {{.*}}: (tensor<3x4xf32>, tensor<3x4xf32>, tensor<1xi8>) -> tensor<3x4xf32>
 // CHECK: %[[VAL_9:.*]] = tosa.log %[[VAL_8]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
 // CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
 // CHECK: return %[[VAL_10]] : !torch.vtensor<[3,4],f32>
@@ -3176,7 +3176,7 @@ func.func @torch.aten.logit$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vt
 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
 // CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_6]], %[[VAL_4]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
 // CHECK: %[[VAL_8:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_4]], %[[VAL_8]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_4]], {{.*}}, {{.*}}: (tensor<3x4xf32>, tensor<3x4xf32>, tensor<1xi8>) -> tensor<3x4xf32>
 // CHECK: %[[VAL_10:.*]] = tosa.log %[[VAL_9]] : (tensor<3x4xf32>) -> tensor<3x4xf32>
 // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
 // CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],f32>
@@ -3211,7 +3211,7 @@ func.func @torch.aten.log$int(%arg0:
!torch.vtensor<[3,4],si32>) -> !torch.vtens // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.693147182> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> // CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<1x1xf32>) -> tensor<1x1xf32> // CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], {{.*}}, {{.*}}: (tensor<3x4xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<3x4xf32> // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> // CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> // CHECK: } @@ -3278,7 +3278,7 @@ func.func @torch.aten.sigmoid$int(%arg0: !torch.vtensor<[3,5],si32>) -> !torch.v // CHECK: %[[VAL_2:.*]] = tosa.sin %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_3:.*]] = tosa.cos %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], {{.*}}, {{.*}}: (tensor<3x4xf32>, tensor<3x4xf32>, tensor<1xi8>) -> tensor<3x4xf32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> // CHECK: } @@ -3296,7 +3296,7 @@ func.func @torch.aten.tan$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vten // CHECK: %[[VAL_3:.*]] = tosa.sin %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_4:.*]] = tosa.cos %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], {{.*}}, {{.*}}: (tensor<3x4xf32>, tensor<3x4xf32>, tensor<1xi8>) -> tensor<3x4xf32> // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> // CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> // CHECK: } @@ -3352,7 +3352,7 @@ func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],s // CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<6x4x2xi32>) -> tensor<24x2xi32> // CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> // CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x2xi32> -// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_9]], %[[VAL_11]] {shift = 0 : i8} : (tensor<24x2xi32>, tensor<1x2xi32>) -> tensor<24x2xi32> +// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_9]], %[[VAL_11]], %{{.*}} : (tensor<24x2xi32>, tensor<1x2xi32>, tensor<1xi8>) -> tensor<24x2xi32> // CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 1 : i32} : (tensor<24x2xi32>) -> tensor<24x1xi32> // CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> // CHECK: %[[VAL_15:.*]] = tosa.gather %[[VAL_8]], %[[VAL_14]] : (tensor<1x24x1xf32>, tensor<1x24xi32>) -> tensor<1x24x1xf32> diff --git a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir 
b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir index 68ca4f0e08ac..abf153fcb0c9 100644 --- a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir +++ b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func.func @torch.aten.mul.Scalar$mixed_type( // CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16>) -> tensor<5xbf16> { // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16> -// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i8} : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16> +// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], {{.*}}, {{.*}}: (tensor<5xbf16>, tensor<1xbf16>, tensor<1xi8>) -> tensor<5xbf16> // CHECK: return %[[VAL_2]] : tensor<5xbf16> // CHECK: } func.func @torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],bf16> { @@ -95,7 +95,7 @@ func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?], // CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { // CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor) -> tensor // CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], {{.*}}, {{.*}}: (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %[[VAL_4]] : tensor // CHECK: } func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],f32> { @@ -121,7 +121,7 @@ func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si1 // CHECK-SAME: %[[VAL_0:.*]]: tensor // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<7.812500e-03> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> // CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_0]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_1]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], {{.*}}, {{.*}}: (tensor, tensor<1x1xf32>, tensor<1xi8>) -> tensor func.func @torch.aten.div.Scalar$int_input_fp_output(%arg0: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],f32> { %int128 = torch.constant.int 128 %0 = torch.aten.div.Scalar %arg0, %int128 : !torch.vtensor<[?, ?],si64>, !torch.int -> !torch.vtensor<[?, ?],f32>
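All of the CHECK-line updates above track a single IR-level change. As a minimal illustrative sketch (the SSA names %a, %b, %shift, and %out are placeholders, not taken from any test above): tosa.mul previously carried its shift as an i8 attribute, and now takes it as a third operand, a rank-1 tensor<1xi8> constant, so each rewritten CHECK line gains one operand and one tensor<1xi8> entry in its trailing function type:

  // Old form, matched by the removed (-) CHECK lines: shift is an attribute.
  //   %out = tosa.mul %a, %b {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
  // New form, matched by the added (+) CHECK lines: shift is an explicit
  // rank-1 i8 constant operand, materialized before the multiply.
  %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %out = tosa.mul %a, %b, %shift : (tensor<3x4xf32>, tensor<3x4xf32>, tensor<1xi8>) -> tensor<3x4xf32>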