diff --git a/externals/llvm-project b/externals/llvm-project index 4ed634719ca5..c27444ab4976 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 4ed634719ca5ba458c2723796a2eef180aaa6df6 +Subproject commit c27444ab4976dd9ff131212f87463f9945ab28d7 diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h index 398926e8168a..4041e522fca1 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h @@ -24,6 +24,7 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op, SmallVector indiceOneDimShape, int32_t dim, ArrayRef indexShape); +// Default function to create TOSA op with shift value mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op, TensorType outType, Value lhs, Value rhs, int32_t shift); @@ -32,8 +33,8 @@ mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op, template TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op, TensorType outType, Value lhs, Value rhs) { - lhs = promoteType(rewriter, lhs, outType); - rhs = promoteType(rewriter, rhs, outType); + lhs = tosa::tosaCastTensorToType(rewriter, lhs, outType).value(); + rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value(); return CreateOpAndInfer(rewriter, op->getLoc(), outType, lhs, rhs); } diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index 8beabb969ed8..eacb1b64ef73 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -10,13 +10,14 @@ #ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H #define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H -#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project 
-#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project +#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace tosa { @@ -49,6 +50,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type); Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, float val); +// Create an int8_t const tosa.mul shift tensor from an int +Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op, + int32_t shift); + // Create a zero constant tensor of the desired type and shape. std::optional getZerosLikeTensor(PatternRewriter &rewriter, Operation *op, Type type); @@ -62,55 +67,24 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, ArrayRef shape, std::optional dtype = {}); -LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, - Value src, Type destType, Value &result); - -Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType); +// Default function to create tosa.cast op. This should be called instead of +// directly calling rewriter.create. 
+std::optional tosaCastTensorToType(PatternRewriter &rewriter, Value src, + TensorType destType); // Creates a TOSA operation and performs shape inference on the individual // op. This allows shape inference during the framework to TOSA lowering. +template +TosaOp CreateOpAndInfer(ImplicitLocOpBuilder &builder, Type result_ty, + Args &&...args) { + return CreateOpAndInferShape(builder, result_ty, args...); +} + template TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty, Args &&...args) { - auto op = rewriter.create(loc, result_ty, args...); - - InferShapedTypeOpInterface shapeInterface = - dyn_cast(op.getOperation()); - if (!shapeInterface) - return op; - - SmallVector returnedShapes; - if (shapeInterface - .inferReturnTypeComponents(op.getContext(), op.getLoc(), - op->getOperands(), op->getAttrDictionary(), - op->getPropertiesStorage(), - op->getRegions(), returnedShapes) - .failed()) - return op; - - // We need to use the element type of the existing result type to generate - // the new result shaped type. This is because rescale can include a cast to - // different bit-width types and does not have a TypeAttr to define the - // target type. - auto result = op->getResult(0); - auto predictedShape = returnedShapes[0]; - auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty); - - // Compute the knowledge based on the inferred type. - auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); - inferredKnowledge.dtype = cast(result_ty).getElementType(); - inferredKnowledge.hasRank = predictedShape.hasRank(); - if (predictedShape.hasRank()) { - for (auto dim : predictedShape.getDims()) { - inferredKnowledge.sizes.push_back(dim); - } - } - - // Compute the new type based on the joined version. 
- auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge); - auto new_ty = newKnowledge.getType(); - result.setType(new_ty); - return op; + ImplicitLocOpBuilder builder(loc, rewriter); + return CreateOpAndInfer(builder, result_ty, args...); } template diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 5bf8a3387d88..7b1b5aec911e 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -63,7 +63,7 @@ class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern { // Non floating point inputs are not supported in TOSA so we cast the input // to result type if (!isa(selfTy.getElementType())) - self = tosa::promoteType(rewriter, self, resultTy); + self = tosa::tosaCastTensorToType(rewriter, self, resultTy).value(); rewriter.replaceOpWithNewOp(op, resultTy, self); @@ -87,7 +87,7 @@ class ConvertAtenUnaryOp : public OpConversionPattern { OpConversionPattern::getTypeConverter()->convertType( op.getType())); - self = tosa::promoteType(rewriter, self, outType); + self = tosa::tosaCastTensorToType(rewriter, self, outType).value(); rewriter.replaceOpWithNewOp(op, outType, self); @@ -130,8 +130,8 @@ class ConvertAtenBinaryOp : public OpConversionPattern { /*round=*/false); } else if constexpr (std::is_same() || std::is_same()) { - lhs = tosa::promoteType(rewriter, lhs, outTy); - rhs = tosa::promoteType(rewriter, rhs, outTy); + lhs = tosa::tosaCastTensorToType(rewriter, lhs, outTy).value(); + rhs = tosa::tosaCastTensorToType(rewriter, rhs, outTy).value(); // Use default NaN Propagation mode "PROPAGATE" for tosa.maximum and // tosa.minimum binaryOp = rewriter.create( @@ -348,7 +348,7 @@ class ConvertAtenAddSubOp : public OpConversionPattern { // right is tensor, rhsType == tensor // right must be cast to same type as the alpha, so MulOp success rhsType = RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType); - rhs = 
rewriter.create(op->getLoc(), rhsType, rhs); + rhs = tosa::tosaCastTensorToType(rewriter, rhs, rhsType).value(); } // Handle scalar value alpha. @@ -381,7 +381,10 @@ class ConvertAtenAddSubOp : public OpConversionPattern { mulAlphaOp); // cast tensor back to tensor - rewriter.replaceOpWithNewOp(op, outType, addOrSubi64Op); + auto result = + tosa::tosaCastTensorToType(rewriter, addOrSubi64Op, outType).value(); + rewriter.replaceOp(op, result); + return success(); } @@ -456,8 +459,9 @@ class ConvertAtenCompareOp : public OpConversionPattern { OpConversionPattern::getTypeConverter()->convertType( op.getType())); if (isBitwiseOp) { - lhs = tosa::promoteType(rewriter, lhs, resultTy); - rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy); + lhs = tosa::tosaCastTensorToType(rewriter, lhs, resultTy).value(); + rhsTensor = + tosa::tosaCastTensorToType(rewriter, rhsTensor, resultTy).value(); } // Support different types comparisons @@ -466,24 +470,27 @@ class ConvertAtenCompareOp : public OpConversionPattern { if (lhsElemTy != rhsElemTy && !isBitwiseOp) { if (isLhsElemFloat && !isRhsElemFloat) { - rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); + rhsTensor = + tosa::tosaCastTensorToType(rewriter, rhsTensor, lhsTy).value(); } else if (!isLhsElemFloat && isRhsElemFloat) { - lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy); + lhs = tosa::tosaCastTensorToType(rewriter, lhs, rhsTensorTy).value(); } else if (isLhsElemFloat && isRhsElemFloat) { auto lhsElemFloatTy = dyn_cast(lhsElemTy); auto rhsElemFloatTy = dyn_cast(rhsElemTy); if (lhsElemFloatTy.getWidth() > rhsElemFloatTy.getWidth()) { - rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); + rhsTensor = + tosa::tosaCastTensorToType(rewriter, rhsTensor, lhsTy).value(); } else { - lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy); + lhs = tosa::tosaCastTensorToType(rewriter, lhs, rhsTensorTy).value(); } } else { auto lhsElemIntTy = dyn_cast(lhsElemTy); auto rhsElemIntTy = dyn_cast(rhsElemTy); if 
(lhsElemIntTy.getWidth() > rhsElemIntTy.getWidth()) { - rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); + rhsTensor = + tosa::tosaCastTensorToType(rewriter, rhsTensor, lhsTy).value(); } else { - lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy); + lhs = tosa::tosaCastTensorToType(rewriter, lhs, rhsTensorTy).value(); } } } @@ -629,7 +636,7 @@ std::optional truncFloatDivWithDivResult(PatternRewriter &rewriter, // towards zero) for float type inputs Value truncFloatDiv(PatternRewriter &rewriter, Operation *op, TensorType outType, Value lhs, Value rhs) { - rhs = tosa::promoteType(rewriter, rhs, outType); + rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value(); auto rhsRcp = rewriter.create(op->getLoc(), rhs.getType(), rhs); @@ -655,8 +662,8 @@ std::optional floorIntDiv(PatternRewriter &rewriter, Operation *op, // TOSA IntDiv requires inputs to be i32 auto i32Type = RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(32)); - lhs = tosa::promoteType(rewriter, lhs, i32Type); - rhs = tosa::promoteType(rewriter, rhs, i32Type); + lhs = tosa::tosaCastTensorToType(rewriter, lhs, i32Type).value(); + rhs = tosa::tosaCastTensorToType(rewriter, rhs, i32Type).value(); auto intDivOp = rewriter.create(op->getLoc(), i32Type, lhs, rhs); @@ -672,14 +679,14 @@ std::optional floorIntDiv(PatternRewriter &rewriter, Operation *op, auto boolType = RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)); - auto lhsMulRhs = rewriter.create(op->getLoc(), i32Type, lhs, rhs, - /*shift=*/0); + auto lhsMulRhs = tosa::createMulOpAndCast(rewriter, op, i32Type, lhs, rhs, + /*shift=*/0); auto lhsRhsDifferentSign = rewriter.create(op->getLoc(), boolType, zero, lhsMulRhs); - auto truncMulRhs = rewriter.create(op->getLoc(), i32Type, - intDivOp, rhs, /*shift=*/0); + auto truncMulRhs = tosa::createMulOpAndCast(rewriter, op, i32Type, intDivOp, + rhs, /*shift=*/0); auto truncMulRhsEqualLhs = rewriter.create(op->getLoc(), boolType, truncMulRhs, lhs); @@ -696,7 
+703,8 @@ std::optional floorIntDiv(PatternRewriter &rewriter, Operation *op, auto selectOp = rewriter.create(op->getLoc(), i32Type, cond, truncMinusOne, intDivOp); - Value result = tosa::promoteType(rewriter, selectOp, outType); + Value result = + tosa::tosaCastTensorToType(rewriter, selectOp, outType).value(); return result; } @@ -755,7 +763,8 @@ class ConvertAtenDivOp : public OpConversionPattern { // The input to the reciprocal is an integer sometimes, and we may need // to promote it to a floating point. Per TOSA specification, the input // types can only be floating point for tosa::ReciprocalOp. - rhsTensor = tosa::promoteType(rewriter, rhsTensor, outType); + rhsTensor = + tosa::tosaCastTensorToType(rewriter, rhsTensor, outType).value(); auto rhsRcp = rewriter.create( op->getLoc(), rhsTensor.getType(), rhsTensor); @@ -792,13 +801,15 @@ class ConvertAtenDivOp : public OpConversionPattern { // TOSA IntDiv requires inputs to be i32 auto i32Type = RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(32)); - lhs = tosa::promoteType(rewriter, lhs, i32Type); - rhsTensor = tosa::promoteType(rewriter, rhsTensor, i32Type); + lhs = tosa::tosaCastTensorToType(rewriter, lhs, i32Type).value(); + rhsTensor = + tosa::tosaCastTensorToType(rewriter, rhsTensor, i32Type).value(); auto intDivOp = rewriter.create(op->getLoc(), i32Type, lhs, rhsTensor); - result = tosa::promoteType(rewriter, intDivOp, outType); + result = + tosa::tosaCastTensorToType(rewriter, intDivOp, outType).value(); } } @@ -843,7 +854,7 @@ class ConvertAtenActivationFunctionOp : public OpConversionPattern { // Non floating point inputs are not supported for activation functions // (erf, sigmoid, tanh) in TOSA so we cast the input to result type if (!isa(selfTy.getElementType())) - self = tosa::promoteType(rewriter, self, resultTy); + self = tosa::tosaCastTensorToType(rewriter, self, resultTy).value(); rewriter.replaceOpWithNewOp(op, resultTy, self); @@ -918,12 +929,14 @@ LogicalResult 
ConvertAtenOp::matchAndRewrite( op->getLoc(), RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)), self, zero); - auto mulTensor = rewriter.create( - op->getLoc(), getTypeConverter()->convertType(op.getType()), self, - alphaTensor, /*shift=*/0); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), cond, self, mulTensor); + auto resultTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto mulTensor = tosa::createMulOpAndCast(rewriter, op, resultTy, self, + alphaTensor, /*shift=*/0); + + rewriter.replaceOpWithNewOp(op, resultTy, cond, self, + mulTensor); return success(); } @@ -978,9 +991,11 @@ class ConvertAtenReductionOp : public OpConversionPattern { std::is_same() || std::is_same() || std::is_same()) { - self = tosa::promoteType( - rewriter, self, - RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1))); + self = tosa::tosaCastTensorToType( + rewriter, self, + RankedTensorType::get(selfTy.getShape(), + rewriter.getIntegerType(1))) + .value(); } // Handle dtype output and bool elem type for ReduceSum and ReduceProd ops @@ -1005,13 +1020,14 @@ class ConvertAtenReductionOp : public OpConversionPattern { dtypeType = rewriter.getIntegerType(dtypeType.getIntOrFloatBitWidth()); - self = tosa::promoteType( - rewriter, self, - RankedTensorType::get(selfTy.getShape(), dtypeType)); + self = tosa::tosaCastTensorToType( + rewriter, self, + RankedTensorType::get(selfTy.getShape(), dtypeType)) + .value(); } } else { if (selfElemTy.isInteger(1)) - self = tosa::promoteType(rewriter, self, outputTy); + self = tosa::tosaCastTensorToType(rewriter, self, outputTy).value(); } } @@ -1284,7 +1300,7 @@ class ConvertAtenSqueezeOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newOutputTy), - self, rewriter.getDenseI64ArrayAttr(newOutputShape)); + self, tosa::getTosaConstShape(rewriter, op->getLoc(), newOutputShape)); rewriter.replaceOpWithNewOp( op, 
OpConversionPattern::getTypeConverter()->convertType( @@ -1385,7 +1401,8 @@ class ConvertAtenPowOp : public OpConversionPattern { // Non floating point inputs are not supported for tosa.pow so we cast the // input to result type if (!isa(selfTy.getElementType())) - selfTensor = tosa::promoteType(rewriter, selfTensor, outType); + selfTensor = + tosa::tosaCastTensorToType(rewriter, selfTensor, outType).value(); } Value expTensor; @@ -1407,7 +1424,8 @@ class ConvertAtenPowOp : public OpConversionPattern { // Non floating point exponents are not supported for tosa.pow so we cast // the exponent to result type if (!isa(expTy.getElementType())) - expTensor = tosa::promoteType(rewriter, expTensor, outType); + expTensor = + tosa::tosaCastTensorToType(rewriter, expTensor, outType).value(); } if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), selfTensor, expTensor) @@ -1643,7 +1661,9 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( lhsBroadcastedTy), - lhs, rewriter.getDenseI64ArrayAttr(lhsBroadcastedShape)); + lhs, + tosa::getTosaConstShape(rewriter, op->getLoc(), + lhsBroadcastedShape)); auto rankBroadcastedRhs = rhsRank == maxInputRank @@ -1652,7 +1672,9 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( rhsBroadcastedTy), - rhs, rewriter.getDenseI64ArrayAttr(rhsBroadcastedShape)); + rhs, + tosa::getTosaConstShape(rewriter, op->getLoc(), + rhsBroadcastedShape)); // TOSA matmul is performed on two 3D inputs and generates a 3D output. 
// Lower ranked tensors are dim-1 reshaped up to 3D @@ -1680,7 +1702,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newType), - tensor, rewriter.getDenseI64ArrayAttr(newShape)); + tensor, tosa::getTosaConstShape(rewriter, op->getLoc(), newShape)); }; // Where broadcasting is required in one or more batch dims, the following @@ -1870,7 +1892,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newLhsType), - lhsReshapeInput, rewriter.getDenseI64ArrayAttr(newLhsShape)); + lhsReshapeInput, + tosa::getTosaConstShape(rewriter, op->getLoc(), newLhsShape)); SmallVector transposedRhsShape; SmallVector transposedRhsDims; @@ -1942,7 +1965,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newRhsType), - transposedRhsValue, rewriter.getDenseI64ArrayAttr(newRhsShape)); + transposedRhsValue, + tosa::getTosaConstShape(rewriter, op->getLoc(), newRhsShape)); } auto matmulLhsShape = makeShapeTorchCompatible( @@ -2084,7 +2108,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( reshapedOpType), - castResult, rewriter.getDenseI64ArrayAttr(reshapedOpShape)); + castResult, + tosa::getTosaConstShape(rewriter, op->getLoc(), reshapedOpShape)); if (opNeedsTranspose) { @@ -2326,7 +2351,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( dyn_cast(getTypeConverter()->convertType(op.getType())); auto resultElemTy = resultTy.getElementType(); - self = tosa::promoteType(rewriter, self, resultTy); + self = tosa::tosaCastTensorToType(rewriter, self, resultTy).value(); Value otherTensor, alphaTensor; @@ -2348,8 +2373,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Failed to equalize ranks among operands and result"); - auto 
multTensor = rewriter.create(op->getLoc(), resultTy, self, - alphaTensor, /*shift=*/0); + auto multTensor = tosa::createMulOpAndCast(rewriter, op, resultTy, self, + alphaTensor, /*shift=*/0); rewriter.replaceOpWithNewOp(op, resultTy, otherTensor, multTensor); @@ -2588,7 +2613,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op->getLoc(), getTypeConverter()->convertType(transformedWeightType), transposedWeight, - rewriter.getDenseI64ArrayAttr(transformedWeightShape)) + tosa::getTosaConstShape(rewriter, op->getLoc(), + transformedWeightShape)) .getResult(); } @@ -2710,7 +2736,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(newType), self, - rewriter.getDenseI64ArrayAttr(newShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), newShape)); return success(); } @@ -2761,12 +2787,13 @@ std::optional computeBatchNorm(Operation *op, auto op3RsqrtOp2 = rewriter.create( op->getLoc(), variance.getType(), op2AddVarEpsilon.getResult()); - auto op4MulOp1Op3 = rewriter.create(op->getLoc(), outType, - op1SubInputMean.getResult(), - op3RsqrtOp2.getResult(), 0); + auto op4MulOp1Op3 = tosa::createMulOpAndCast( + rewriter, op, dyn_cast(outType), op1SubInputMean.getResult(), + op3RsqrtOp2.getResult(), 0); - auto op5MulOp4Scale = rewriter.create( - op->getLoc(), outType, op4MulOp1Op3.getResult(), weight, 0); + auto op5MulOp4Scale = + tosa::createMulOpAndCast(rewriter, op, dyn_cast(outType), + op4MulOp1Op3.getResult(), weight, 0); return rewriter .create(op->getLoc(), outType, op5MulOp4Scale.getResult(), @@ -2821,7 +2848,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( result = rewriter.create( op->getLoc(), newType, toBcast, - rewriter.getDenseI64ArrayAttr(newShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), newShape)); return success(); }; @@ -2955,7 +2982,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } return rewriter.create( - op.getLoc(), outType, sumDiv, 
rewriter.getDenseI64ArrayAttr(outShape)); + op.getLoc(), outType, sumDiv, + tosa::getTosaConstShape(rewriter, op->getLoc(), outShape)); }; // TOSA has integer Div so, compute reciprocal of element count to be used in @@ -2989,19 +3017,19 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Compute mean. Value sum = computeSumAndReshape(input, inputType, bcastOutType, bcastOutShape); - Value meanVal = rewriter.create(op.getLoc(), bcastOutType, sum, - elemCntRcp, /*shift=*/0); + Value meanVal = tosa::createMulOpAndCast(rewriter, op, bcastOutType, sum, + elemCntRcp, /*shift=*/0); // Compute variance. Value squareSumSub = rewriter.create(op.getLoc(), inputType, input, meanVal); - Value squareSum = rewriter.create(op.getLoc(), inputType, - squareSumSub, squareSumSub, 0); + Value squareSum = tosa::createMulOpAndCast(rewriter, op, inputType, + squareSumSub, squareSumSub, 0); Value squareSumReduced = computeSumAndReshape(squareSum, inputType, bcastOutType, bcastOutShape); - Value varianceVal = rewriter.create( - op.getLoc(), bcastOutType, squareSumReduced, elemCntRcp, /*shift=*/0); + Value varianceVal = tosa::createMulOpAndCast( + rewriter, op, bcastOutType, squareSumReduced, elemCntRcp, /*shift=*/0); // Reshape weight and bias. 
SmallVector weightAndBiasBcastShape; @@ -3016,11 +3044,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value weightVal = rewriter.create( op.getLoc(), weightAndMeanBcastType, weight, - rewriter.getDenseI64ArrayAttr(weightAndBiasBcastShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), weightAndBiasBcastShape)); Value biasVal = rewriter.create( op.getLoc(), weightAndMeanBcastType, bias, - rewriter.getDenseI64ArrayAttr(weightAndBiasBcastShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), weightAndBiasBcastShape)); double eps; if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) @@ -3124,9 +3152,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape), selfType.getElementType()); - auto reshapeOp = - rewriter.create(op.getLoc(), newType, adaptor.getSelf(), - rewriter.getDenseI64ArrayAttr(newShape)); + auto reshapeOp = rewriter.create( + op.getLoc(), newType, adaptor.getSelf(), + tosa::getTosaConstShape(rewriter, op->getLoc(), newShape)); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), reshapeOp); @@ -3183,7 +3211,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(newType), adaptor.getSelf(), - rewriter.getDenseI64ArrayAttr(newShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), newShape)); return success(); } @@ -3243,7 +3271,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // If input is not a float type then cast it to output type auto selfElemTy = selfType.getElementType(); if (!isa(selfElemTy)) - self = tosa::promoteType(rewriter, self, outType); + self = tosa::tosaCastTensorToType(rewriter, self, outType).value(); // Constant value of ln2. 
SmallVector ln2Shape(selfType.getRank(), 1); @@ -3259,8 +3287,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.create(op.getLoc(), ln2Op.getType(), ln2Op); auto logOp = rewriter.create(op.getLoc(), outType, self); - rewriter.replaceOpWithNewOp(op, outType, logOp, rcpOp, - /*shift=*/0); + auto result = tosa::createMulOpAndCast(rewriter, op, outType, logOp, rcpOp, + /*shift=*/0); + + rewriter.replaceOp(op, result.getResult()); return success(); } @@ -3357,7 +3387,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), - rewriter.getDenseI64ArrayAttr(outShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), outShape)); return success(); } @@ -3386,7 +3416,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = dyn_cast(adaptor.getInput().getType()); + auto self = adaptor.getInput(); + auto selfType = dyn_cast(self.getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -3400,8 +3431,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (train) return rewriter.notifyMatchFailure(op, "train must be false"); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), adaptor.getInput()); + auto resultType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto result = tosa::tosaCastTensorToType(rewriter, self, resultType).value(); + + rewriter.replaceOp(op, result); return success(); } @@ -3460,7 +3494,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), - rewriter.getDenseI64ArrayAttr(outShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), outShape)); return success(); } @@ -3497,26 +3531,33 @@ approximateErfOp(ConversionPatternRewriter &rewriter, Operation *op, Value x, 
mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a4).failed()) return std::nullopt; - auto a1X = rewriter.create(loc, outType, a1, absX, /*shift=*/0); + auto a1X = + tosa::createMulOpAndCast(rewriter, op, outType, a1, absX, /*shift=*/0); auto sum = rewriter.create(loc, outType, a1X, one); - auto x2 = rewriter.create(loc, outType, absX, absX, /*shift=*/0); - auto a2X = rewriter.create(loc, outType, a2, x2, /*shift=*/0); + auto x2 = + tosa::createMulOpAndCast(rewriter, op, outType, absX, absX, /*shift=*/0); + auto a2X = + tosa::createMulOpAndCast(rewriter, op, outType, a2, x2, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a2X); - auto x3 = rewriter.create(loc, outType, x2, absX, /*shift=*/0); - auto a3X = rewriter.create(loc, outType, a3, x3, /*shift=*/0); + auto x3 = + tosa::createMulOpAndCast(rewriter, op, outType, x2, absX, /*shift=*/0); + auto a3X = + tosa::createMulOpAndCast(rewriter, op, outType, a3, x3, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a3X); - auto x4 = rewriter.create(loc, outType, x3, absX, /*shift=*/0); - auto a4X = rewriter.create(loc, outType, a4, x4, /*shift=*/0); + auto x4 = + tosa::createMulOpAndCast(rewriter, op, outType, x3, absX, /*shift=*/0); + auto a4X = + tosa::createMulOpAndCast(rewriter, op, outType, a4, x4, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a4X); auto rcprl = rewriter.create(loc, outType, sum); - auto rcprl2 = - rewriter.create(loc, outType, rcprl, rcprl, /*shift=*/0); - auto rcprl4 = - rewriter.create(loc, outType, rcprl2, rcprl2, /*shift=*/0); + auto rcprl2 = tosa::createMulOpAndCast(rewriter, op, outType, rcprl, rcprl, + /*shift=*/0); + auto rcprl4 = tosa::createMulOpAndCast(rewriter, op, outType, rcprl2, rcprl2, + /*shift=*/0); auto erf = rewriter.create(loc, outType, one, rcprl4); // Deal with negative x. 
@@ -3549,17 +3590,18 @@ buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x, auto loc = op->getLoc(); // buildNormalCdf, mean = zero, sigma = one - auto outType = x.getType(); + auto outType = dyn_cast(x.getType()); auto mean = zero; Value xMinusMean = rewriter.create(loc, outType, x, mean); - Value erfArg = rewriter.create(loc, outType, xMinusMean, rsqrt2, - /*shift=*/0); + Value erfArg = + tosa::createMulOpAndCast(rewriter, op, outType, xMinusMean, rsqrt2, + /*shift=*/0); Value erf = approximateErfOp(rewriter, op, erfArg, dtype).value(); Value erfPlus1 = rewriter.create(loc, outType, one, erf); - Value normalCdf = rewriter.create(loc, outType, oneHalf, - erfPlus1, /*shift=*/0); + Value normalCdf = tosa::createMulOpAndCast(rewriter, op, outType, oneHalf, + erfPlus1, /*shift=*/0); return normalCdf; } @@ -3595,12 +3637,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // GELU(x) = x * CDF(x) Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy).value(); - cdf = rewriter.createOrFold( - op->getLoc(), - cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); + cdf = tosa::tosaCastTensorToType(rewriter, cdf, selfType).value(); + + auto result = tosa::createMulOpAndCast(rewriter, op, resultType, self, cdf, + /*shift=*/0); - rewriter.replaceOpWithNewOp(op, resultType, self, cdf, - /*shift=*/0); + rewriter.replaceOp(op, result.getResult()); } else if (approximate.compare("tanh") == 0) { // "tanh" approximate // GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) @@ -3644,8 +3686,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .value(); // 0.5 * x - auto halfInput = rewriter.create(op->getLoc(), resultType, - half, self, /*shift=*/0); + auto halfInput = tosa::createMulOpAndCast(rewriter, op, resultType, half, + self, /*shift=*/0); // sqrt(2/pi) auto sqrtTwoOverPi = @@ -3657,16 +3699,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // 0.044715 * x^3 auto inputPowThreeMul = - rewriter.create(op->getLoc(), 
resultType, magicNumber, - inputPowThree.getResult(), /*shift=*/0); + tosa::createMulOpAndCast(rewriter, op, resultType, magicNumber, + inputPowThree.getResult(), /*shift=*/0); // x + 0.044715 * x^3 auto inputPowThreeMulAdd = rewriter.create( op->getLoc(), resultType, self, inputPowThreeMul.getResult()); // sqrt(2/pi) * (x + 0.044715 * x^3) - auto sqrtTwoOverPiMul = rewriter.create( - op->getLoc(), resultType, sqrtTwoOverPi.getResult(), + auto sqrtTwoOverPiMul = tosa::createMulOpAndCast( + rewriter, op, resultType, sqrtTwoOverPi.getResult(), inputPowThreeMulAdd.getResult(), /*shift=*/0); // tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) @@ -3677,9 +3719,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto tanhAdd = rewriter.create(op->getLoc(), resultType, one, tanh.getResult()); - rewriter.replaceOpWithNewOp( - op, resultType, halfInput.getResult(), tanhAdd.getResult(), - /*shift=*/0); + auto result = tosa::createMulOpAndCast(rewriter, op, resultType, + halfInput.getResult(), + tanhAdd.getResult(), /*shift=*/0); + + rewriter.replaceOp(op, result.getResult()); } else { return rewriter.notifyMatchFailure(op, "Unsupported approximation algorithm"); @@ -3733,22 +3777,25 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Failed to equalize ranks among operands and result"); Value inputSquared = - rewriter.create(loc, selfType, self, self, /*shift=*/0); - Value negHalfInputSquared = rewriter.create( - loc, selfType, inputSquared, negOneHalf, /*shift=*/0); + tosa::createMulOpAndCast(rewriter, op, selfType, self, self, /*shift=*/0); + Value negHalfInputSquared = tosa::createMulOpAndCast( + rewriter, op, selfType, inputSquared, negOneHalf, /*shift=*/0); Value dinput = rewriter.create(loc, selfType, negHalfInputSquared); Value cdf = buildUnitNormalCdf(rewriter, op, self, selfElemTy).value(); - Value dinputInput = - rewriter.create(loc, selfType, dinput, self, /*shift=*/0); - Value dinputInputAlpha = rewriter.create( - loc, selfType, dinputInput, kAlphaHalf, /*shift=*/0); + 
Value dinputInput = tosa::createMulOpAndCast(rewriter, op, selfType, dinput, + self, /*shift=*/0); + Value dinputInputAlpha = tosa::createMulOpAndCast( + rewriter, op, selfType, dinputInput, kAlphaHalf, /*shift=*/0); Value cdfExt = rewriter.create(loc, selfType, dinputInputAlpha, cdf); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), - adaptor.getGradOutput(), cdfExt, - /*shift=*/0); + + auto resultTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto result = tosa::createMulOpAndCast( + rewriter, op, resultTy, adaptor.getGradOutput(), cdfExt, /*shift=*/0); + + rewriter.replaceOp(op, result.getResult()); return success(); } @@ -3901,7 +3948,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(newWeightShape), weightType.getElementType()), - weight, rewriter.getDenseI64ArrayAttr(newWeightShape)); + weight, tosa::getTosaConstShape(rewriter, op->getLoc(), newWeightShape)); int64_t numIndices = 1; if (indicesType.hasStaticShape()) { @@ -3916,13 +3963,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(newIndicesShape), indicesType.getElementType()), - indices, rewriter.getDenseI64ArrayAttr(newIndicesShape)); - - auto castIndices = rewriter.create( - op->getLoc(), - RankedTensorType::get(makeShapeLLVMCompatible(newIndicesShape), - rewriter.getIntegerType(32)), - reshapedIndices); + indices, + tosa::getTosaConstShape(rewriter, op->getLoc(), newIndicesShape)); + + auto castIndices = + tosa::tosaCastTensorToType( + rewriter, reshapedIndices, + RankedTensorType::get(makeShapeLLVMCompatible(newIndicesShape), + rewriter.getIntegerType(32))) + .value(); SmallVector intermediateOutShape = {1, numIndices, weightShape[1]}; auto gatherOp = rewriter.create( @@ -3933,8 +3982,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, outType, gatherOp, - rewriter.getDenseI64ArrayAttr( - 
makeShapeTorchCompatible(outType.getShape()))); + tosa::getTosaConstShape(rewriter, op->getLoc(), + makeShapeTorchCompatible(outType.getShape()))); return success(); } @@ -4032,7 +4081,8 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { } auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim); - auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape); + auto prunedShapeValue = + tosa::getTosaConstShape(rewriter, op->getLoc(), prunedShape); Value reduceOp; if constexpr (std::is_same() || @@ -4077,7 +4127,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { if (argMaxOp.getType() != indicesType) { argMaxOp = rewriter.create( op->getLoc(), indicesType, argMaxOp, - rewriter.getDenseI64ArrayAttr(reducedShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), reducedShape)); } if (!keepDim) { @@ -4085,7 +4135,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), selfElemType), - reduceOp, prunedShapeAttr); + reduceOp, prunedShapeValue); } rewriter.replaceOp(op, {reduceOp, argMaxOp}); @@ -4181,8 +4231,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto slice = rewriter.create( op.getLoc(), outTy.cloneWith(sliceShape, outTy.getElementType()), - reshaped, rewriter.getDenseI64ArrayAttr(startSlice), - rewriter.getDenseI64ArrayAttr(sliceShape)); + reshaped, tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), sliceShape)); auto out = tosa::reshapeTo(op->getLoc(), rewriter, slice, outTy.getShape()); @@ -4270,7 +4320,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op->getLoc(), RankedTensorType::get(makeShapeTorchCompatible(targetInputShape), selfElemTy), - self, rewriter.getDenseI64ArrayAttr(targetInputShape)); + self, + tosa::getTosaConstShape(rewriter, op->getLoc(), targetInputShape)); SmallVector tileOpShape; for (int64_t i = 0; i < outputRank; i++) { @@ -4332,11 +4383,11 @@ 
LogicalResult ConvertAtenOp::matchAndRewrite( // index i64 to i32 for tosa compatitable if (indexType.getElementType() != rewriter.getIntegerType(32)) { - index = rewriter.create( - op->getLoc(), - RankedTensorType::get(indexType.getShape(), - rewriter.getIntegerType(32)), - index); + index = tosa::tosaCastTensorToType( + rewriter, index, + RankedTensorType::get(indexType.getShape(), + rewriter.getIntegerType(32))) + .value(); } // Get positive dim @@ -4407,7 +4458,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( index = rewriter.create( op->getLoc(), RankedTensorType::get(indexShape, indexType.getElementType()), index, - rewriter.getDenseI64ArrayAttr(indexShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), indexShape)); } // Dynamic shape check @@ -4418,9 +4469,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // index i64 to i32 for tosa compatible if (indexType.getElementType() != rewriter.getIntegerType(32)) { - index = rewriter.create( - op->getLoc(), - RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index); + index = tosa::tosaCastTensorToType( + rewriter, index, + RankedTensorType::get(indexShape, rewriter.getIntegerType(32))) + .value(); } // Get positive dim @@ -4461,7 +4513,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto reshapedIndices = rewriter.create( op->getLoc(), indicesInputRankType, index, - rewriter.getDenseI64ArrayAttr(indicesInputRankShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), indicesInputRankShape)); SmallVector tileShape(indicesInputRankShape); SmallVector expandedIndicesShape(indicesInputRankShape); @@ -4557,10 +4609,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // index i64 to i32 for tosa compatible if (indexType.getElementType() != rewriter.getIntegerType(32)) - index = rewriter.create( - op->getLoc(), - RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), - index); + index = + tosa::tosaCastTensorToType( + rewriter, index, + RankedTensorType::get(indexShape, 
rewriter.getIntegerType(32))) + .value(); // Expand last dim of index to tf indices [3] -> [3,1] // convert [0,0,0] to [[0],[0],[0]] @@ -4572,7 +4625,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto indicesTfOneDim = tosa::CreateOpAndInfer( rewriter, op->getLoc(), RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)), - index, rewriter.getDenseI64ArrayAttr(indiceShapeOneDim)); + index, + tosa::getTosaConstShape(rewriter, op->getLoc(), indiceShapeOneDim)); // create concat tensor for indicesTf // ([[0],[0],[0]], [[1],[2],[3]]) @@ -4700,10 +4754,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Make type of index tosa compatible, i64 to i32. if (indexType.getElementType() != rewriter.getIntegerType(32)) { - index = rewriter.create( - op->getLoc(), - RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), - index); + index = + tosa::tosaCastTensorToType( + rewriter, index, + RankedTensorType::get(indexShape, rewriter.getIntegerType(32))) + .value(); } index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op, @@ -4718,7 +4773,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto indicesTfOneDim = tosa::CreateOpAndInfer( rewriter, op->getLoc(), RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)), - index, rewriter.getDenseI64ArrayAttr(indiceShapeOneDim)); + index, + tosa::getTosaConstShape(rewriter, op->getLoc(), indiceShapeOneDim)); // create concat tensor for indicesTf indicesTfConcatTensors.push_back(indicesTfOneDim.getResult()); @@ -4790,7 +4846,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Update the tensor array with the max rank-extended form indicesTfConcatTensors[i] = rewriter.create( op->getLoc(), reshapeOutputTy, unreshapedIdxTensor, - rewriter.getDenseI64ArrayAttr(broadcastedShapeTf)); + tosa::getTosaConstShape(rewriter, op->getLoc(), + broadcastedShapeTf)); } // Construct the max rank broadcasted form of all index tensors with @@ -4870,10 +4927,11 @@ LogicalResult 
ConvertAtenOp::matchAndRewrite( auto indexShape = indexType.getShape(); // index i64 to i32 for tosa compatible if (indexType.getElementType() != rewriter.getIntegerType(32)) { - index = rewriter.create( - op->getLoc(), - RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), - index); + index = + tosa::tosaCastTensorToType( + rewriter, index, + RankedTensorType::get(indexShape, rewriter.getIntegerType(32))) + .value(); } index = @@ -4889,7 +4947,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( indicesTf = tosa::CreateOpAndInfer( rewriter, op->getLoc(), RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index, - rewriter.getDenseI64ArrayAttr(indicesShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), indicesShape)); } if (!indicesTf) { @@ -4959,9 +5017,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // index i64 to i32 for tosa compatitable if (indexType.getElementType() != rewriter.getIntegerType(32)) { - index = rewriter.create( - op->getLoc(), - RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index); + index = tosa::tosaCastTensorToType( + rewriter, index, + RankedTensorType::get(indexShape, rewriter.getIntegerType(32))) + .value(); } // Get positive dim @@ -5232,8 +5291,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.create(op->getLoc(), selfType, rhsSubOp); auto lhsAbsOp = rewriter.create(op->getLoc(), otherType, other); - auto mulOp = rewriter.create(op->getLoc(), otherType, - rtolConstOp, lhsAbsOp, /*shift=*/0); + auto mulOp = tosa::createMulOpAndCast(rewriter, op, otherType, rtolConstOp, + lhsAbsOp, /*shift=*/0); auto addOp = rewriter.create(op->getLoc(), otherType, atolConstOp, mulOp); @@ -5401,9 +5460,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Failed to equalize ranks among operands and result"); - self = tosa::promoteType(rewriter, self, resultType); - min = tosa::promoteType(rewriter, min, resultType); - max = tosa::promoteType(rewriter, max, 
resultType); + self = tosa::tosaCastTensorToType(rewriter, self, resultType).value(); + min = tosa::tosaCastTensorToType(rewriter, min, resultType).value(); + max = tosa::tosaCastTensorToType(rewriter, max, resultType).value(); // max(xi, min_valuei) // Use default NaN Propagation mode "PROPAGATE" for tosa.maximum @@ -5575,8 +5634,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "failed to generate constant tensor for arange"); } auto result = maybeResult.value(); + result = tosa::tosaCastTensorToType(rewriter, result, resultType).value(); + + rewriter.replaceOp(op, result); - rewriter.replaceOpWithNewOp(op, resultType, result); return success(); } @@ -5631,6 +5692,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only tensor types with static shape are supported"); + auto resultTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + // The non_blocking should be a constant `False`. bool nonBlocking; if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking))) { @@ -5647,12 +5711,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (llvm::equal(selfShape, srcShape) || selfShape.size() == 0) { // If we reach here, then it means the given case is handled by implicit // broadcasting done by tosa. 
- Value result; - if (failed(tosa::tosaCastTensorToType( - rewriter, op, adaptor.getSrc(), - getTypeConverter()->convertType(op.getType()), result))) - return rewriter.notifyMatchFailure( - op, "unimplemented: cast to result type not supported"); + Value result = + tosa::tosaCastTensorToType(rewriter, adaptor.getSrc(), resultTy) + .value(); rewriter.replaceOp(op, result); return success(); } @@ -5709,10 +5770,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto resultTy = cast( getTypeConverter()->convertType(op.getResult().getType())); - Value result; - if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.getSelf(), - resultTy, result))) - return rewriter.notifyMatchFailure(op, "conversion to result type failed"); + Value result = + tosa::tosaCastTensorToType(rewriter, adaptor.getSelf(), resultTy).value(); rewriter.replaceOp(op, result); return success(); @@ -5770,7 +5829,7 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern { std::is_same(); if (selfTy.getElementType() != outElemTy) - self = rewriter.create(op.getLoc(), outType, self); + self = tosa::tosaCastTensorToType(rewriter, self, outType).value(); Value divTensor; if (isRemainderOp) { @@ -5778,8 +5837,8 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern { if (isa(outElemTy)) { auto otherTensorReciprocal = rewriter.create( op.getLoc(), otherTensor.getType(), otherTensor); - divTensor = rewriter.create( - op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0); + divTensor = tosa::createMulOpAndCast( + rewriter, op, outType, self, otherTensorReciprocal, /*shift=*/0); divTensor = rewriter.create(op.getLoc(), outType, divTensor); } else { @@ -5794,19 +5853,21 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern { // TOSA IntDiv requires inputs to be i32 auto i32Type = RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(32)); - self = tosa::promoteType(rewriter, self, i32Type); - otherTensor = tosa::promoteType(rewriter, otherTensor, 
i32Type); + self = tosa::tosaCastTensorToType(rewriter, self, i32Type).value(); + otherTensor = + tosa::tosaCastTensorToType(rewriter, otherTensor, i32Type).value(); auto intDivTensor = rewriter.create( op->getLoc(), i32Type, self, otherTensor); - divTensor = tosa::promoteType(rewriter, intDivTensor, outType); + divTensor = + tosa::tosaCastTensorToType(rewriter, intDivTensor, outType).value(); } } - auto mulTensor = rewriter.create(op.getLoc(), outType, - otherTensor, divTensor, - /*shift=*/0); + auto mulTensor = + tosa::createMulOpAndCast(rewriter, op, outType, otherTensor, divTensor, + /*shift=*/0); rewriter.replaceOpWithNewOp(op, outType, self, mulTensor); return success(); @@ -5955,7 +6016,8 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { RankedTensorType::get(makeShapeTorchCompatible(resultShape), resultElemTy), transposedOutput, - rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); + tosa::getTosaConstShape(rewriter, op->getLoc(), + makeShapeTorchCompatible(resultShape))); } rewriter.replaceOpWithNewOp(op, resultTy, result); @@ -6245,7 +6307,7 @@ class ConvertAtenMaxPool1dOp op->getLoc(), RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), selfTy.getElementType()), - self, rewriter.getDenseI64ArrayAttr(rank4Shape)); + self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape)); SmallVector dilationArray; if (!matchPattern(op.getDilation(), @@ -6343,7 +6405,7 @@ class ConvertAtenAvgPool1dOp op->getLoc(), RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), selfTy.getElementType()), - self, rewriter.getDenseI64ArrayAttr(rank4Shape)); + self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape)); SmallVector dilationArray{1, 1}; if (failed(getOutputTypeAndPoolingParameters { auto constOp = tosa::getConstTensor(rewriter, op, values, shape).value(); - rewriter.replaceOpWithNewOp(op, outType, constOp); + auto result = + tosa::tosaCastTensorToType(rewriter, constOp, outType).value(); + + 
rewriter.replaceOp(op, result); return success(); } @@ -6466,7 +6531,8 @@ class ConvertAtenFillOp : public OpConversionPattern { auto fillValueMatchedInputRankTensor = rewriter.create( op->getLoc(), fillValueMatchedInputRankType, fillValue, - rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), + fillValueMatchedInputRankShape)); auto tileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), outType.getShape()); @@ -6484,8 +6550,11 @@ class ConvertAtenFillOp : public OpConversionPattern { op, "Fill value must be a scalar constant"); } - rewriter.replaceOpWithNewOp(op, outType, - fillValueTargetTensor); + auto result = + tosa::tosaCastTensorToType(rewriter, fillValueTargetTensor, outType) + .value(); + + rewriter.replaceOp(op, result); return success(); } @@ -6542,10 +6611,8 @@ class ConvertAtenMaskedFillOp : public OpConversionPattern { auto rhsTensor = rhsType ? rhs : rhsAsTensor; auto rhsTensorType = dyn_cast(rhsTensor.getType()); if (rhsTensorType.getElementType() != outElemTy) - rhsTensor = rewriter.create( - op.getLoc(), - RankedTensorType::get(rhsTensorType.getShape(), outElemTy), - rhsTensor); + rhsTensor = + tosa::tosaCastTensorToType(rewriter, rhsTensor, outType).value(); if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, rhsTensor) .failed()) @@ -6580,7 +6647,11 @@ class ConvertAtenCloneOp : public OpConversionPattern { auto outType = dyn_cast( OpConversionPattern::getTypeConverter()->convertType( op.getType())); - rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf()); + + auto result = + tosa::tosaCastTensorToType(rewriter, adaptor.getSelf(), outType) + .value(); + rewriter.replaceOp(op, result); return success(); } @@ -6685,7 +6756,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType); for (auto &tensor : builtinTensors) - tensor = tosa::promoteType(rewriter, tensor, outType); + tensor = 
tosa::tosaCastTensorToType(rewriter, tensor, outType).value(); auto result = tosa::CreateOpAndInfer( rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim)); @@ -6709,11 +6780,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( cast(typeConverter->convertType(op.getType())); auto elementType = resultType.getElementType(); - if (isa(selfTy.getElementType())) { - self = rewriter.createOrFold( - op->getLoc(), RankedTensorType::get(resultType.getShape(), elementType), - self); - } + if (isa(selfTy.getElementType())) + self = tosa::tosaCastTensorToType(rewriter, self, resultType).value(); auto oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}, elementType).value(); @@ -7010,8 +7078,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Failed to equalize ranks among operands and result"); - rewriter.replaceOpWithNewOp(op, resultType, self, trilMask, - /*shift=*/0); + auto result = + tosa::createMulOpAndCast(rewriter, op, resultType, self, trilMask, + /*shift=*/0); + rewriter.replaceOp(op, result.getResult()); return success(); } @@ -7106,15 +7176,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto ceilInput = rewriter.create(op->getLoc(), resultTy, self); - auto floorInputDivByTwo = rewriter.create( - op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0); + auto floorInputDivByTwo = tosa::createMulOpAndCast( + rewriter, op, resultTy, floorInput.getResult(), oneHalf, /*shift=*/0); auto floorDivResult = rewriter.create( op->getLoc(), resultTy, floorInputDivByTwo.getResult()); // (floor(input) // 2) * 2 - auto evenComparison = rewriter.create( - op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0); + auto evenComparison = tosa::createMulOpAndCast( + rewriter, op, resultTy, floorDivResult.getResult(), two, /*shift=*/0); // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0 auto floorInputEven = rewriter.create( @@ -7296,8 +7366,8 @@ LogicalResult 
ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Failed to equalize ranks among operands and result"); - Value diagonalTensor = rewriter.create( - op->getLoc(), transposedInputType, selfTransposed, diagonalMask, + Value diagonalTensor = tosa::createMulOpAndCast( + rewriter, op, transposedInputType, selfTransposed, diagonalMask, /*shift=*/0); auto resultShape = makeShapeTorchCompatible(resultType.getShape()); @@ -7319,8 +7389,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( startSlice[targetDim1] = std::abs(offset); diagonalTensor = rewriter.create( op->getLoc(), transposedInputType, diagonalTensor, - rewriter.getDenseI64ArrayAttr(startSlice), - rewriter.getDenseI64ArrayAttr(sizeSlice)); + tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice)); } // Apply Reduce Sum to get the result @@ -7426,7 +7496,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto scatterSrc = rewriter.create( op->getLoc(), RankedTensorType::get(makeShapeTorchCompatible(indexShape), selfElemTy), - self, rewriter.getDenseI64ArrayAttr(indexShape)); + self, tosa::getTosaConstShape(rewriter, op->getLoc(), indexShape)); // Create a const zero tensor to scatter the input onto SmallVector zeroShape; @@ -7580,7 +7650,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( selfType.getElementType()) .value(); - result = tosa::promoteType(rewriter, result, resultType); + result = tosa::tosaCastTensorToType(rewriter, result, resultType).value(); rewriter.replaceOp(op, {result}); @@ -7646,8 +7716,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op->getLoc(), RankedTensorType::get(selfShape, rewriter.getI1Type()), threshold, self); - self = tosa::promoteType(rewriter, self, resultType); - grad = tosa::promoteType(rewriter, grad, resultType); + self = tosa::tosaCastTensorToType(rewriter, self, resultType).value(); + grad = tosa::tosaCastTensorToType(rewriter, grad, resultType).value(); if 
(mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, zero).failed() || mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, grad).failed()) @@ -7708,7 +7778,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto self1D = rewriter.create( op->getLoc(), RankedTensorType::get({selfNumElems}, selfElemTy), self, - rewriter.getDenseI64ArrayAttr({selfNumElems})); + tosa::getTosaConstShape(rewriter, op->getLoc(), {selfNumElems})); // Calculate the target elements indices SmallVector targetIndicesVec; @@ -7757,7 +7827,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto result = rewriter.create( op->getLoc(), resultType, gatherOp.value(), - rewriter.getDenseI64ArrayAttr(outputSize)); + tosa::getTosaConstShape(rewriter, op->getLoc(), outputSize)); rewriter.replaceOp(op, {result.getResult()}); @@ -7802,7 +7872,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // here, which is more simple and quicker. rewriter.replaceOpWithNewOp( op, resultType, self, - rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); + tosa::getTosaConstShape(rewriter, op->getLoc(), + makeShapeTorchCompatible(resultShape))); return success(); } @@ -7845,8 +7916,9 @@ Value reflectionPadAlongAxis(Value input, ArrayRef unpaddedShape, auto leftPadType = RankedTensorType::get(leftPadShape, inputElemTy); auto leftPadSlice = rewriter.create( - loc, leftPadType, input, rewriter.getDenseI64ArrayAttr(leftStartSlice), - rewriter.getDenseI64ArrayAttr(leftSizeSlice)); + loc, leftPadType, input, + tosa::getTosaConstShape(rewriter, loc, leftStartSlice), + tosa::getTosaConstShape(rewriter, loc, leftSizeSlice)); auto leftPad = rewriter.create( loc, leftPadType, leftPadSlice.getResult(), static_cast(axis)); @@ -7878,8 +7950,8 @@ Value reflectionPadAlongAxis(Value input, ArrayRef unpaddedShape, auto rightPadSlice = rewriter.create( loc, rightPadType, input, - rewriter.getDenseI64ArrayAttr(rightStartSlice), - rewriter.getDenseI64ArrayAttr(rightSizeSlice)); + tosa::getTosaConstShape(rewriter, 
loc, rightStartSlice), + tosa::getTosaConstShape(rewriter, loc, rightSizeSlice)); auto rightPad = rewriter.create( loc, rightPadType, rightPadSlice.getResult(), @@ -8125,8 +8197,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto leftPadSlice = rewriter.create( op->getLoc(), leftPadSliceType, self, - rewriter.getDenseI64ArrayAttr(leftStartSlice), - rewriter.getDenseI64ArrayAttr(leftSizeSlice)); + tosa::getTosaConstShape(rewriter, op->getLoc(), leftStartSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), leftSizeSlice)); for (int64_t i = 0; i < paddingLeft; i++) sideTensors.push_back(leftPadSlice.getResult()); @@ -8150,8 +8222,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto rightPadSlice = rewriter.create( op->getLoc(), rightPadSliceType, self, - rewriter.getDenseI64ArrayAttr(rightStartSlice), - rewriter.getDenseI64ArrayAttr(rightSizeSlice)); + tosa::getTosaConstShape(rewriter, op->getLoc(), rightStartSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), rightSizeSlice)); for (int64_t i = 0; i < paddingRight; i++) sideTensors.push_back(rightPadSlice.getResult()); @@ -8185,8 +8257,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto topPadSlice = rewriter.create( op->getLoc(), topPadSliceType, selfSidePadded, - rewriter.getDenseI64ArrayAttr(topStartSlice), - rewriter.getDenseI64ArrayAttr(topSizeSlice)); + tosa::getTosaConstShape(rewriter, op->getLoc(), topStartSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), topSizeSlice)); for (int64_t i = 0; i < paddingTop; i++) resultTensors.push_back(topPadSlice.getResult()); @@ -8213,8 +8285,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto bottomPadSlice = rewriter.create( op->getLoc(), bottomPadSliceType, selfSidePadded, - rewriter.getDenseI64ArrayAttr(bottomStartSlice), - rewriter.getDenseI64ArrayAttr(bottomSizeSlice)); + tosa::getTosaConstShape(rewriter, op->getLoc(), bottomStartSlice), + tosa::getTosaConstShape(rewriter, op->getLoc(), bottomSizeSlice)); for (int64_t i = 0; i < 
paddingBottom; i++) resultTensors.push_back(bottomPadSlice.getResult()); @@ -8264,7 +8336,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // approach here, which is more simple and quicker. rewriter.replaceOpWithNewOp( op, resultType, self, - rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); + tosa::getTosaConstShape(rewriter, op->getLoc(), + makeShapeTorchCompatible(resultShape))); return success(); } @@ -8296,8 +8369,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( dyn_cast(typeConverter->convertType(op.getType())); auto resultShape = resultType.getShape(); - self = tosa::promoteType(rewriter, self, resultType); - vec2 = tosa::promoteType(rewriter, vec2, resultType); + self = tosa::tosaCastTensorToType(rewriter, self, resultType).value(); + vec2 = tosa::tosaCastTensorToType(rewriter, vec2, resultType).value(); SmallVector resultShapeIndex1Replaced({resultShape[0], 1}); SmallVector resultShapeIndex0Replaced({1, resultShape[1]}); @@ -8307,7 +8380,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op->getLoc(), RankedTensorType::get(resultShapeIndex1Replaced, resultType.getElementType()), - self, rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced)); + self, + tosa::getTosaConstShape(rewriter, op->getLoc(), + resultShapeIndex1Replaced)); auto selfTileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), resultShapeIndex0Replaced); @@ -8320,7 +8395,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op->getLoc(), RankedTensorType::get(resultShapeIndex0Replaced, resultType.getElementType()), - vec2, rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced)); + vec2, + tosa::getTosaConstShape(rewriter, op->getLoc(), + resultShapeIndex0Replaced)); auto vec2TileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), resultShapeIndex1Replaced); @@ -8471,7 +8548,8 @@ class ConvertUpsampleNearest2dForward : public OpConversionPattern { auto reshapedSelf = rewriter.create( op->getLoc(), RankedTensorType::get(reshapedSelfShape, 
selfElemTy), - self, rewriter.getDenseI64ArrayAttr(reshapedSelfShape)); + self, + tosa::getTosaConstShape(rewriter, op->getLoc(), reshapedSelfShape)); // Calculate PyTorch-styled gather indices SmallVector targetIndicesVec; @@ -8513,7 +8591,7 @@ class ConvertUpsampleNearest2dForward : public OpConversionPattern { auto result = rewriter.create( op->getLoc(), resultType, gatherOp.value(), - rewriter.getDenseI64ArrayAttr(resultShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), resultShape)); rewriter.replaceOp(op, {result.getResult()}); @@ -8549,7 +8627,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // If input is not a float type then cast it to result element type auto selfElemTy = selfType.getElementType(); if (!isa(selfElemTy)) - self = tosa::promoteType(rewriter, self, resultType); + self = tosa::tosaCastTensorToType(rewriter, self, resultType).value(); bool isEpsNone = isa(op.getEps().getType()); @@ -8587,9 +8665,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto oneMinusZiReciprocal = rewriter.create( op->getLoc(), resultType, oneMinusZi.getResult()); - auto mulOp = rewriter.create(op->getLoc(), resultType, zi, - oneMinusZiReciprocal.getResult(), - /*shift=*/0); + auto mulOp = tosa::createMulOpAndCast(rewriter, op, resultType, zi, + oneMinusZiReciprocal.getResult(), + /*shift=*/0); auto result = rewriter.create(op->getLoc(), resultType, mulOp.getResult()); @@ -8623,7 +8701,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // If input is not a float type then cast it to result element type auto selfElemTy = selfType.getElementType(); if (!isa(selfElemTy)) - self = tosa::promoteType(rewriter, self, resultType); + self = tosa::tosaCastTensorToType(rewriter, self, resultType).value(); auto one = tosa::getConstTensor(rewriter, op, 1.0f, {}, resultElemTy).value(); @@ -8668,7 +8746,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // If input is not a float type then cast it to result element type auto selfElemTy = selfType.getElementType(); if 
(!isa(selfElemTy)) - self = tosa::promoteType(rewriter, self, resultType); + self = tosa::tosaCastTensorToType(rewriter, self, resultType).value(); auto ten = tosa::getConstTensor(rewriter, op, 10.0f, {}, resultElemTy) .value(); @@ -8687,8 +8765,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto reciprocalOp = rewriter.create( op->getLoc(), constTenType, logOfTen.getResult()); - auto result = rewriter.create( - op->getLoc(), resultType, logOfSelf.getResult(), reciprocalOp.getResult(), + auto result = tosa::createMulOpAndCast( + rewriter, op, resultType, logOfSelf.getResult(), reciprocalOp.getResult(), /*shift=*/0); rewriter.replaceOp(op, {result.getResult()}); @@ -8722,7 +8800,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // If input is not a float type then cast it to result element type auto selfElemTy = selfType.getElementType(); if (!isa(selfElemTy)) - self = tosa::promoteType(rewriter, self, resultType); + self = tosa::tosaCastTensorToType(rewriter, self, resultType).value(); auto one = tosa::getConstTensor(rewriter, op, 1.0f, {}, resultElemTy).value(); @@ -8763,7 +8841,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Non floating point inputs are not supported in TOSA so we cast the input // to result type if (!isa(selfType.getElementType())) - self = tosa::promoteType(rewriter, self, resultType); + self = tosa::tosaCastTensorToType(rewriter, self, resultType).value(); auto sinOp = rewriter.create(op->getLoc(), resultType, self); @@ -8772,8 +8850,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto reciprocalOp = rewriter.create(op->getLoc(), resultType, cosOp); - auto result = rewriter.create( - op->getLoc(), resultType, sinOp.getResult(), reciprocalOp.getResult(), + auto result = tosa::createMulOpAndCast( + rewriter, op, resultType, sinOp.getResult(), reciprocalOp.getResult(), /*shift=*/0); rewriter.replaceOp(op, {result.getResult()}); @@ -8847,7 +8925,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto result = rewriter.create( 
op->getLoc(), RankedTensorType::get({1}, selfElemTy), self, - rewriter.getDenseI64ArrayAttr({1})); + tosa::getTosaConstShape(rewriter, op->getLoc(), {1})); rewriter.replaceOp(op, {result.getResult()}); return success(); @@ -8949,7 +9027,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto reshapeOp = rewriter.create( op->getLoc(), RankedTensorType::get(intermediaryShape, resultElemTy), - gatherNdOp.value(), rewriter.getDenseI64ArrayAttr(intermediaryShape)); + gatherNdOp.value(), + tosa::getTosaConstShape(rewriter, op->getLoc(), intermediaryShape)); // Permute dims to the correct result order SmallVector permutedDims; diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 1f18cabd8cb0..3ff9bf9c3448 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -114,13 +114,17 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op, return indicesDim; } +// Default function to create TOSA op with shift value tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op, TensorType outType, Value lhs, Value rhs, int32_t shift) { - lhs = promoteType(rewriter, lhs, outType); - rhs = promoteType(rewriter, rhs, outType); + lhs = tosa::tosaCastTensorToType(rewriter, lhs, outType).value(); + rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value(); + + auto constShift = tosa::getTosaMulShiftConstTensor(rewriter, op, shift); + return tosa::CreateOpAndInfer(rewriter, op->getLoc(), outType, - lhs, rhs, shift); + lhs, rhs, constShift); } template <> @@ -134,8 +138,8 @@ createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op, op, "tosa.int_div only supports integer type"); } - lhs = promoteType(rewriter, lhs, outType); - rhs = promoteType(rewriter, rhs, outType); + lhs = tosa::tosaCastTensorToType(rewriter, lhs, outType).value(); + rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value(); return 
tosa::CreateOpAndInfer(rewriter, op->getLoc(), outType, lhs, rhs); } @@ -186,7 +190,8 @@ std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, auto indicesChosenAxis = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesOneDimShape, indexType.getElementType()), - indexValue, rewriter.getDenseI64ArrayAttr(indicesOneDimShape)); + indexValue, + tosa::getTosaConstShape(rewriter, op->getLoc(), indicesOneDimShape)); SmallVector concatInputs; for (auto dim = 0; dim < paramsRank; dim++) { @@ -347,14 +352,16 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, auto tosaValuesReshapeOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(tosaValuesShape, paramsType.getElementType()), - paramsValue, rewriter.getDenseI64ArrayAttr(tosaValuesShape)); + paramsValue, + tosa::getTosaConstShape(rewriter, op->getLoc(), tosaValuesShape)); // %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) -> // tensor<8x3xi32> Flatten the input indices tensor to an [W, ND] matrix. 
Value indicesMatrixReshapeOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), - indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape)); + indicesValue, + tosa::getTosaConstShape(rewriter, op->getLoc(), indicesMatrixShape)); SmallVector flattenedCoeffVec; // [12,3,1] // flattenedCoeffVec = [4,3,1] @@ -386,8 +393,8 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, // Multiply the coefficients by the coordinates // %5 = "tosa.mul"(%3, %4) {shift = 0 : i32} : (tensor<8x3xi32>, // tensor<3xi32>) -> tensor<8x3xi32> - auto flattenedIndicesMulOp = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), + auto flattenedIndicesMulOp = tosa::createMulOpAndCast( + rewriter, op, GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0); @@ -407,7 +414,7 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, rewriter, op->getLoc(), GetTypeFromTensorShape(tosaIndicesShape, indicesType.getElementType()), flattenedIndicesReduceOp.getResult(), - rewriter.getDenseI64ArrayAttr(tosaIndicesShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), tosaIndicesShape)); // Now the gather op itself // %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> @@ -424,7 +431,8 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, // %10 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> return tosa::CreateOpAndInfer( rewriter, op->getLoc(), resultType, tosaGatherOp.getResult(), - rewriter.getDenseI64ArrayAttr(resultType.getShape())) + tosa::getTosaConstShape(rewriter, op->getLoc(), + resultType.getShape())) .getResult(); } @@ -568,7 +576,7 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, auto tosaFillValuesOneReshapeOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(oneShape, fillValuesType.getElementType()), - fillValues, 
rewriter.getDenseI64ArrayAttr(oneShape)); + fillValues, tosa::getTosaConstShape(rewriter, op->getLoc(), oneShape)); // [0] -> [0,0,0] SmallVector tileShape({W}); // {3} @@ -586,7 +594,8 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, GetTypeFromTensorShape(newTosaFillValuesShape, fillValuesType.getElementType()), tosaFillValuesTileOp.getResult(), - rewriter.getDenseI64ArrayAttr(newTosaFillValuesShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), + newTosaFillValuesShape)); fillValues = newTosaFillValuesReshapeOp.getResult(); fillValuesType = dyn_cast(fillValues.getType()); } @@ -606,7 +615,8 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, rewriter, op->getLoc(), GetTypeFromTensorShape(tosaFillValuesShape, fillValuesType.getElementType()), - fillValues, rewriter.getDenseI64ArrayAttr(tosaFillValuesShape)); + fillValues, + tosa::getTosaConstShape(rewriter, op->getLoc(), tosaFillValuesShape)); // Reshape/Flatten input to 3d tensor // [[1, 2, 3, 4]] -> [[[1], [2], [3], [4]]] @@ -615,7 +625,8 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, auto tosaValuesReshapeOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(tosaInputValuesShape, paramsType.getElementType()), - paramsValue, rewriter.getDenseI64ArrayAttr(tosaInputValuesShape)); + paramsValue, + tosa::getTosaConstShape(rewriter, op->getLoc(), tosaInputValuesShape)); // Reshape/Flatten the input indices tensor to a 2d [W, ND] matrix. 
// [[0, 1], [0, 2], [0, 3]] -> [[0, 1], [0, 2], [0, 3]] @@ -624,7 +635,8 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, Value indicesMatrixReshapeOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), - indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape)); + indicesValue, + tosa::getTosaConstShape(rewriter, op->getLoc(), indicesMatrixShape)); SmallVector flattenedCoeffVec; // [4,1] // flattenedCoeffVec = [4,1] @@ -657,8 +669,8 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, // [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]] // %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>, // tensor<2xi32>) -> tensor<3x2xi32> - auto flattenedIndicesMulOp = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), + auto flattenedIndicesMulOp = tosa::createMulOpAndCast( + rewriter, op, GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0); @@ -680,7 +692,7 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, rewriter, op->getLoc(), GetTypeFromTensorShape(tosaIndicesShape, indicesType.getElementType()), flattenedIndicesReduceOp.getResult(), - rewriter.getDenseI64ArrayAttr(tosaIndicesShape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), tosaIndicesShape)); // Now the Scatter op itself // %16 = "tosa.scatter"(%9, %15, %10) : (tensor<1x4x1xi64>, tensor<1x3xi32>, @@ -700,7 +712,8 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, // (tensor<1x4x1xi64>) -> tensor<1x4xi64> return tosa::CreateOpAndInfer( rewriter, op->getLoc(), resultType, tosaScatterOp.getResult(), - rewriter.getDenseI64ArrayAttr(resultType.getShape())) + tosa::getTosaConstShape(rewriter, op->getLoc(), + resultType.getShape())) .getResult(); } @@ -778,7 +791,7 @@ std::optional convertReduceOpCommon( if (!keep_dims) { auto reshape_op = CreateOpAndInfer( rewriter, 
op->getLoc(), output_type, val, - rewriter.getDenseI64ArrayAttr(output_shape)); + tosa::getTosaConstShape(rewriter, op->getLoc(), output_shape)); val = reshape_op.getResult(); } } @@ -1006,8 +1019,8 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, .failed()) return std::nullopt; - return CreateOpAndInfer(rewriter, op->getLoc(), output_type, - val.value(), div_const, 0) + return tosa::createMulOpAndCast(rewriter, op, output_type, val.value(), + div_const, 0) .getResult(); } @@ -1072,7 +1085,7 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, } auto input_value_casted = - tosa::promoteType(rewriter, input_value, output_type); + tosa::tosaCastTensorToType(rewriter, input_value, output_type).value(); auto absVal = CreateOpAndInfer( rewriter, op->getLoc(), RankedTensorType::get(input_type.getShape(), elemType), diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 2243c8dcfd83..6cb4afad1d04 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -144,7 +144,8 @@ Value buildSlice(PatternRewriter &rewriter, Value &input, RankedTensorType::get( llvm::SmallVector(size.size(), ShapedType::kDynamic), cast(input.getType()).getElementType()), - input, start, size); + input, tosa::getTosaConstShape(rewriter, input.getLoc(), start), + tosa::getTosaConstShape(rewriter, input.getLoc(), size)); } // Check if scale32 mode is used for given output_element_type @@ -163,6 +164,19 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, return const_op.getResult(); } +// Create an int8_t const tosa.mul shift tensor from an int +Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op, + int32_t shift) { + auto shiftType = RankedTensorType::get({1}, rewriter.getIntegerType(8)); + auto shiftAttr = DenseElementsAttr::get( + shiftType, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 
shift)); + + auto constShift = + rewriter.create(op->getLoc(), shiftType, shiftAttr); + + return constShift.getResult(); +} + // Create a zero constant tensor of the desired type and shape. std::optional getZerosLikeTensor(PatternRewriter &rewriter, Operation *op, Type type) { @@ -212,8 +226,9 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, rewriter.create(op->getLoc(), const_type, const_attr); if (dtype) { - return rewriter.createOrFold( - op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); + return tosa::tosaCastTensorToType(rewriter, const_op, + RankedTensorType::get(shape, *dtype)) + .value(); } return const_op.getResult(); } @@ -242,8 +257,9 @@ std::optional getConstTensor(PatternRewriter &rewriter, rewriter.create(op->getLoc(), const_type, const_attr); if (dtype) { - return rewriter.createOrFold( - op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); + return tosa::tosaCastTensorToType(rewriter, const_op, + RankedTensorType::get(shape, *dtype)) + .value(); } return const_op.getResult(); } @@ -299,8 +315,9 @@ std::optional getConstTensor(PatternRewriter &rewriter, rewriter.create(op->getLoc(), const_type, const_attr); if (dtype) { - return rewriter.createOrFold( - op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); + return tosa::tosaCastTensorToType(rewriter, const_op, + RankedTensorType::get(shape, *dtype)) + .value(); } return const_op.getResult(); } @@ -373,10 +390,11 @@ std::optional getConstTensor(PatternRewriter &rewriter, return failure(); } -// Template specialization for float -LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, - Value src, Type destType, Value &result) { - +// Default function to create tosa.cast op. This should be called instead of +// directly calling rewriter.create. 
+std::optional tosaCastTensorToType(PatternRewriter &rewriter, Value src, + TensorType destType) { + Operation *op = src.getDefiningOp(); TensorType srcType = dyn_cast(src.getType()); Type srcElemTy = srcType.getElementType(); Type destElemTy = dyn_cast(destType).getElementType(); @@ -390,93 +408,36 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, // casting only when needed (the default value of `--strict` mode will be // off). // if (failed(checkValidityOfCast(srcElemTy, destElemTy))) - // return rewriter.notifyMatchFailure( - // op, "casting to result dtype is invalid or unsupported"); - - if (destElemTy.isInteger(1)) { - auto srcType = dyn_cast(src.getType()); - SmallVector srcShape(srcType.getShape()); - uint64_t num_total_elements = 1; - for (int64_t a : srcShape) - num_total_elements *= a; - - std::optional constOp; - if (srcElemTy.isInteger(64)) { - SmallVector values(num_total_elements, 0); - constOp = - tosa::getConstTensor(rewriter, op, values, srcShape).value(); - } else if (srcElemTy.isInteger(32)) { - SmallVector values(num_total_elements, 0); - constOp = - tosa::getConstTensor(rewriter, op, values, srcShape).value(); - } else if (srcElemTy.isInteger(8)) { - SmallVector values(num_total_elements, 0); - constOp = - tosa::getConstTensor(rewriter, op, values, srcShape).value(); - } else if (srcElemTy.isInteger(16)) { - SmallVector values(num_total_elements, 0); - constOp = - tosa::getConstTensor(rewriter, op, values, srcShape).value(); - } else if (srcElemTy.isBF16()) { - SmallVector values(num_total_elements, 0.0); - constOp = - tosa::getConstTensor(rewriter, op, values, srcShape, srcElemTy) - .value(); - } else if (srcElemTy.isF32()) { - SmallVector values(num_total_elements, 0.0); - constOp = - tosa::getConstTensor(rewriter, op, values, srcShape).value(); - } else if (srcElemTy.isF64()) { - SmallVector values(num_total_elements, 0.0); - constOp = - tosa::getConstTensor(rewriter, op, values, srcShape).value(); - } else { 
- op->dump(); - op->emitError("Unsupported conversion to i1"); - return failure(); - } - Value equalToZero = rewriter.create(op->getLoc(), destType, - src, constOp.value()); - result = rewriter.create(op->getLoc(), destType, - equalToZero); - } else { - if (llvm::isa(srcElemTy) && destElemTy.isInteger()) { - // for float->int conversion, tosa.cast performs round-to-nearest - // torch performs round-to-zero instead - // generate round-to-zero conversion prior to tosa.cast to match with - // expected torch behavior - auto floor = rewriter.create(op->getLoc(), srcType, src); - auto ceil = rewriter.create(op->getLoc(), srcType, src); - - auto zeroValue = - tosa::getConstTensor(rewriter, op, 0, {}, srcElemTy).value(); - - if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue) - .failed()) - return rewriter.notifyMatchFailure( - op, "Failed to equalize ranks among operands and result"); - - auto boolType = srcType.clone(rewriter.getIntegerType(1)); - auto isNegative = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), boolType, zeroValue, src); - src = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), srcType, isNegative, ceil, floor); - } - result = rewriter.create(op->getLoc(), destType, src); + // return std::nullopt; + + if (srcElemTy == destElemTy) + return src; + + if (llvm::isa(srcElemTy) && destElemTy.isInteger() && + !destElemTy.isInteger(1)) { + // For float->int conversion, tosa.cast performs round-to-nearest. + // PyTorch performs round-to-zero instead. + // Generate round-to-zero conversion prior to tosa.cast to match with + // expected torch behavior. 
+ auto floor = rewriter.create(op->getLoc(), srcType, src); + auto ceil = rewriter.create(op->getLoc(), srcType, src); + + auto zeroValue = + tosa::getConstTensor(rewriter, op, 0, {}, srcElemTy).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue) + .failed()) + return std::nullopt; + + auto boolType = srcType.clone(rewriter.getIntegerType(1)); + auto isNegative = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), boolType, zeroValue, src); + src = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), srcType, isNegative, ceil, floor); } - return success(); -} -Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { - Operation *op = input.getDefiningOp(); - TensorType inType = cast(input.getType()); - - if (inType.getElementType() != outType.getElementType()) { - TensorType promotedType = - inType.cloneWith(inType.getShape(), outType.getElementType()); - return rewriter.create(op->getLoc(), promotedType, input); - } - return input; + TensorType castedSrcType = srcType.clone(destElemTy); + return rewriter.create(op->getLoc(), castedSrcType, src); } TypedValue reshapeTo(Location loc, PatternRewriter &rewriter, @@ -487,7 +448,7 @@ TypedValue reshapeTo(Location loc, PatternRewriter &rewriter, auto newTy = RankedTensorType::get(newShape, tensorTy.getElementType()); return rewriter.create( - loc, newTy, val, rewriter.getDenseI64ArrayAttr(newShape)); + loc, newTy, val, tosa::getTosaConstShape(rewriter, loc, newShape)); } TypedValue transposeBy(Location loc, diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 941d710e4f2e..c32bf7a3f0fa 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -26,10 +26,20 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch. 
// ----- -// CHECK-DAG: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<4x8xf32>) -> tensor<1x4x8xf32> -// CHECK-DAG: %[[VAL_3:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<8x16xf32>) -> tensor<1x8x16xf32> -// CHECK-NEXT: %[[VAL_4:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] : (tensor<1x4x8xf32>, tensor<1x8x16xf32>) -> tensor<1x4x16xf32> -// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x4x16xf32>) -> tensor<4x16xf32> +// CHECK-LABEL: func.func @torch.aten.mm$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,8],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[4,16],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,8],f32> -> tensor<4x8xf32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 4, 8]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor<4x8xf32>, !tosa.shape<3>) -> tensor<1x4x8xf32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 8, 16]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_6]] : (tensor<8x16xf32>, !tosa.shape<3>) -> tensor<1x8x16xf32> +// CHECK: %[[VAL_8:.*]] = tosa.matmul %[[VAL_5]], %[[VAL_7]] : (tensor<1x4x8xf32>, tensor<1x8x16xf32>) -> tensor<1x4x16xf32> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[4, 16]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_9]] : (tensor<1x4x16xf32>, !tosa.shape<2>) -> tensor<4x16xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<4x16xf32> -> !torch.vtensor<[4,16],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[4,16],f32> func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[4,8],f32>, %arg1: 
!torch.vtensor<[8,16],f32>) -> !torch.vtensor<[4,16],f32> { %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[4,16],f32> return %0 : !torch.vtensor<[4,16],f32> @@ -37,12 +47,24 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[4,8],f32>, %arg1: !torch.v // ----- -// CHECK: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<6xf32>) -> tensor<1x6xf32> -// CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<6xf32>) -> tensor<6x1xf32> -// CHECK-NEXT: %[[VAL_4:.+]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1x6xf32>) -> tensor<1x1x6xf32> -// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> -// CHECK-NEXT: %[[VAL_6:.+]] = tosa.matmul %[[VAL_4]], %[[VAL_5]] : (tensor<1x1x6xf32>, tensor<1x6x1xf32>) -> tensor<1x1x1xf32> -// CHECK-NEXT: %[[VAL_7:.+]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x1x1xf32>) -> tensor +// CHECK-LABEL: func.func @torch.aten.matmul_1d( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[6],f32>) -> !torch.vtensor<[],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[6],f32> -> tensor<6xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6],f32> -> tensor<6xf32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 6]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor<6xf32>, !tosa.shape<2>) -> tensor<1x6xf32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[6, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_6]] : (tensor<6xf32>, !tosa.shape<2>) -> tensor<6x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[1, 1, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_9:.*]] = 
tosa.reshape %[[VAL_5]], %[[VAL_8]] : (tensor<1x6xf32>, !tosa.shape<3>) -> tensor<1x1x6xf32> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 6, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_10]] : (tensor<6x1xf32>, !tosa.shape<3>) -> tensor<1x6x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.matmul %[[VAL_9]], %[[VAL_11]] : (tensor<1x1x6xf32>, tensor<1x6x1xf32>) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<> : tensor<0xindex>} : () -> !tosa.shape<0> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]], %[[VAL_13]] : (tensor<1x1x1xf32>, !tosa.shape<0>) -> tensor +// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[VAL_15]] : !torch.vtensor<[],f32> func.func @torch.aten.matmul_1d(%arg0 : !torch.vtensor<[6],f32>, %arg1 : !torch.vtensor<[6],f32>) -> !torch.vtensor<[],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[6],f32>, !torch.vtensor<[6],f32> -> !torch.vtensor<[],f32> return %0 : !torch.vtensor<[],f32> @@ -50,11 +72,22 @@ func.func @torch.aten.matmul_1d(%arg0 : !torch.vtensor<[6],f32>, %arg1 : !torch. 
// ----- -// CHECK: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<6xf32>) -> tensor<1x6xf32> -// CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1x6xf32>) -> tensor<1x1x6xf32> -// CHECK-NEXT: %[[VAL_4:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> -// CHECK-NEXT: %[[VAL_5:.+]] = tosa.matmul %[[VAL_3]], %[[VAL_4]] : (tensor<1x1x6xf32>, tensor<1x6x1xf32>) -> tensor<1x1x1xf32> -// CHECK-NEXT: %[[VAL_6:.+]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<1x1x1xf32>) -> tensor<1xf32> +// CHECK-LABEL: func.func @torch.aten.matmul_12d( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[1],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[6,1],f32> -> tensor<6x1xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6],f32> -> tensor<6xf32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 6]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor<6xf32>, !tosa.shape<2>) -> tensor<1x6xf32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 1, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor<1x6xf32>, !tosa.shape<3>) -> tensor<1x1x6xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[1, 6, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_8]] : (tensor<6x1xf32>, !tosa.shape<3>) -> tensor<1x6x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.matmul %[[VAL_7]], %[[VAL_9]] : (tensor<1x1x6xf32>, tensor<1x6x1xf32>) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]], %[[VAL_11]] : (tensor<1x1x1xf32>, 
!tosa.shape<1>) -> tensor<1xf32> +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: return %[[VAL_13]] : !torch.vtensor<[1],f32> func.func @torch.aten.matmul_12d(%arg0 : !torch.vtensor<[6],f32>, %arg1 : !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[1],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[6],f32>, !torch.vtensor<[6,1],f32> -> !torch.vtensor<[1],f32> return %0 : !torch.vtensor<[1],f32> @@ -62,11 +95,22 @@ func.func @torch.aten.matmul_12d(%arg0 : !torch.vtensor<[6],f32>, %arg1 : !torch // ----- -// CHECK: %[[VAL_2:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<6xf32>) -> tensor<6x1xf32> -// CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<2x6xf32>) -> tensor<1x2x6xf32> -// CHECK-NEXT: %[[VAL_4:.+]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> -// CHECK-NEXT: %[[VAL_5:.+]] = tosa.matmul %[[VAL_3]], %[[VAL_4]] : (tensor<1x2x6xf32>, tensor<1x6x1xf32>) -> tensor<1x2x1xf32> -// CHECK-NEXT: %[[VAL_6:.+]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<1x2x1xf32>) -> tensor<2xf32> +// CHECK-LABEL: func.func @torch.aten.matmul_21d( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,6],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[6],f32>) -> !torch.vtensor<[2],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[6],f32> -> tensor<6xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,6],f32> -> tensor<2x6xf32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[6, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_4]] : (tensor<6xf32>, !tosa.shape<2>) -> tensor<6x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 2, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_6]] : (tensor<2x6xf32>, 
!tosa.shape<3>) -> tensor<1x2x6xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[1, 6, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_8]] : (tensor<6x1xf32>, !tosa.shape<3>) -> tensor<1x6x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.matmul %[[VAL_7]], %[[VAL_9]] : (tensor<1x2x6xf32>, tensor<1x6x1xf32>) -> tensor<1x2x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]], %[[VAL_11]] : (tensor<1x2x1xf32>, !tosa.shape<1>) -> tensor<2xf32> +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<2xf32> -> !torch.vtensor<[2],f32> +// CHECK: return %[[VAL_13]] : !torch.vtensor<[2],f32> func.func @torch.aten.matmul_21d(%arg0 : !torch.vtensor<[2,6],f32>, %arg1 : !torch.vtensor<[6],f32>) -> !torch.vtensor<[2],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[2,6],f32>, !torch.vtensor<[6],f32> -> !torch.vtensor<[2],f32> return %0 : !torch.vtensor<[2],f32> @@ -74,10 +118,20 @@ func.func @torch.aten.matmul_21d(%arg0 : !torch.vtensor<[2,6],f32>, %arg1 : !tor // ----- -// CHECK: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<2x6xf32>) -> tensor<1x2x6xf32> -// CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<6x8xf32>) -> tensor<1x6x8xf32> -// CHECK-NEXT: %[[VAL_4:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] : (tensor<1x2x6xf32>, tensor<1x6x8xf32>) -> tensor<1x2x8xf32> -// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x2x8xf32>) -> tensor<2x8xf32> +// CHECK-LABEL: func.func @torch.aten.mm_2d( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,6],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[6,8],f32>) -> !torch.vtensor<[2,8],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[6,8],f32> -> tensor<6x8xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor 
%[[VAL_0]] : !torch.vtensor<[2,6],f32> -> tensor<2x6xf32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 2, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor<2x6xf32>, !tosa.shape<3>) -> tensor<1x2x6xf32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 6, 8]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_6]] : (tensor<6x8xf32>, !tosa.shape<3>) -> tensor<1x6x8xf32> +// CHECK: %[[VAL_8:.*]] = tosa.matmul %[[VAL_5]], %[[VAL_7]] : (tensor<1x2x6xf32>, tensor<1x6x8xf32>) -> tensor<1x2x8xf32> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[2, 8]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_9]] : (tensor<1x2x8xf32>, !tosa.shape<2>) -> tensor<2x8xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x8xf32> -> !torch.vtensor<[2,8],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[2,8],f32> func.func @torch.aten.mm_2d(%arg0 : !torch.vtensor<[2,6],f32>, %arg1 : !torch.vtensor<[6,8],f32>) -> !torch.vtensor<[2,8],f32> { %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,6],f32>, !torch.vtensor<[6,8],f32> -> !torch.vtensor<[2,8],f32> return %0 : !torch.vtensor<[2,8],f32> @@ -137,13 +191,25 @@ func.func @torch.aten.mm_i16(%arg0: !torch.vtensor<[4,8],si16>, %arg1: !torch.vt // ----- -// CHECK: %[[VAL_2:.+]] = tosa.cast %{{[0-9]+}} : (tensor<4x8xf32>) -> tensor<4x8xf16> -// CHECK-NEXT: %[[VAL_3:.+]] = tosa.cast %{{[0-9]+}} : (tensor<8x16xf32>) -> tensor<8x16xf16> -// CHECK-NEXT: %[[VAL_4:.+]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4x8xf16>) -> tensor<1x4x8xf16> -// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<8x16xf16>) -> tensor<1x8x16xf16> -// CHECK-NEXT: %[[VAL_6:.+]] = tosa.matmul %[[VAL_4]], %[[VAL_5]] : (tensor<1x4x8xf16>, tensor<1x8x16xf16>) -> tensor<1x4x16xf16> -// 
CHECK-NEXT: %[[VAL_7:.+]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x4x16xf16>) -> tensor<4x16xf16> - +// CHECK-LABEL: func.func @torch.aten.mm_f32_to_f16( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,8],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[4,16],f16> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,8],f32> -> tensor<4x8xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.bool false +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = torch.constant.int 5 +// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_3]] : (tensor<4x8xf32>) -> tensor<4x8xf16> +// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_2]] : (tensor<8x16xf32>) -> tensor<8x16xf16> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[1, 4, 8]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_9]] : (tensor<4x8xf16>, !tosa.shape<3>) -> tensor<1x4x8xf16> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<[1, 8, 16]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_11]] : (tensor<8x16xf16>, !tosa.shape<3>) -> tensor<1x8x16xf16> +// CHECK: %[[VAL_13:.*]] = tosa.matmul %[[VAL_10]], %[[VAL_12]] : (tensor<1x4x8xf16>, tensor<1x8x16xf16>) -> tensor<1x4x16xf16> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[4, 16]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_13]], %[[VAL_14]] : (tensor<1x4x16xf16>, !tosa.shape<2>) -> tensor<4x16xf16> +// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<4x16xf16> -> !torch.vtensor<[4,16],f16> +// CHECK: return %[[VAL_16]] : !torch.vtensor<[4,16],f16> func.func @torch.aten.mm_f32_to_f16(%arg0: !torch.vtensor<[4,8],f32>, %arg1: !torch.vtensor<[8,16],f32>) -> 
!torch.vtensor<[4,16],f16> { %false = torch.constant.bool false %none = torch.constant.none @@ -156,10 +222,20 @@ func.func @torch.aten.mm_f32_to_f16(%arg0: !torch.vtensor<[4,8],f32>, %arg1: !to // ----- -// CHECK: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<10x10x6x2xf32>) -> tensor<100x6x2xf32> -// CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<10x10x2x6xf32>) -> tensor<100x2x6xf32> -// CHECK-NEXT: %[[VAL_4:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] : (tensor<100x6x2xf32>, tensor<100x2x6xf32>) -> tensor<100x6x6xf32> -// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<100x6x6xf32>) -> tensor<10x10x6x6xf32> +// CHECK-LABEL: func.func @torch.aten.matmul_4d( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,10,6,2],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[10,10,2,6],f32>) -> !torch.vtensor<[10,10,6,6],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[10,10,2,6],f32> -> tensor<10x10x2x6xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,10,6,2],f32> -> tensor<10x10x6x2xf32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[100, 6, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor<10x10x6x2xf32>, !tosa.shape<3>) -> tensor<100x6x2xf32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[100, 2, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_6]] : (tensor<10x10x2x6xf32>, !tosa.shape<3>) -> tensor<100x2x6xf32> +// CHECK: %[[VAL_8:.*]] = tosa.matmul %[[VAL_5]], %[[VAL_7]] : (tensor<100x6x2xf32>, tensor<100x2x6xf32>) -> tensor<100x6x6xf32> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[10, 10, 6, 6]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_9]] : (tensor<100x6x6xf32>, !tosa.shape<4>) -> 
tensor<10x10x6x6xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<10x10x6x6xf32> -> !torch.vtensor<[10,10,6,6],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[10,10,6,6],f32> func.func @torch.aten.matmul_4d(%arg0 : !torch.vtensor<[10,10,6,2],f32>, %arg1 : !torch.vtensor<[10,10,2,6],f32>) -> !torch.vtensor<[10,10,6,6],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,10,6,2],f32>, !torch.vtensor<[10,10,2,6],f32> -> !torch.vtensor<[10,10,6,6],f32> return %0 : !torch.vtensor<[10,10,6,6],f32> @@ -167,17 +243,28 @@ func.func @torch.aten.matmul_4d(%arg0 : !torch.vtensor<[10,10,6,2],f32>, %arg1 : // ----- -// CHECK: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<10x6x2xf32>) -> tensor<1x10x6x2xf32> -// CHECK-NEXT: %[[VAL_3:.+]] = "tosa.const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %[[VAL_4:.+]] = tosa.transpose %[[VAL_2]], %[[VAL_3]] : (tensor<1x10x6x2xf32>, tensor<4xi32>) -> tensor<10x1x6x2xf32> -// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<10x1x6x2xf32>) -> tensor<10x6x2xf32> -// CHECK-NEXT: %[[VAL_6:.+]] = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %[[VAL_7:.+]] = tosa.transpose %0, %[[VAL_6]] : (tensor<10x10x2x6xf32>, tensor<4xi32>) -> tensor<10x2x10x6xf32> -// CHECK-NEXT: %[[VAL_8:.+]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<10x2x10x6xf32>) -> tensor<10x2x60xf32> -// CHECK-NEXT: %[[VAL_9:.+]] = tosa.matmul %[[VAL_5]], %[[VAL_8]] : (tensor<10x6x2xf32>, tensor<10x2x60xf32>) -> tensor<10x6x60xf32> -// CHECK-NEXT: %[[VAL_10:.+]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<10x6x60xf32>) -> tensor<10x6x10x6xf32> -// CHECK-NEXT: %[[VAL_11:.+]] = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %[[VAL_12:.+]] = tosa.transpose %[[VAL_10]], %[[VAL_11]] : (tensor<10x6x10x6xf32>, 
tensor<4xi32>) -> tensor<10x10x6x6xf32> +// CHECK-LABEL: func.func @torch.aten.matmul_4d_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,6,2],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[10,10,2,6],f32>) -> !torch.vtensor<[10,10,6,6],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[10,10,2,6],f32> -> tensor<10x10x2x6xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,6,2],f32> -> tensor<10x6x2xf32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 10, 6, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor<10x6x2xf32>, !tosa.shape<4>) -> tensor<1x10x6x2xf32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_5]], %[[VAL_6]] : (tensor<1x10x6x2xf32>, tensor<4xi32>) -> tensor<10x1x6x2xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[10, 6, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_8]] : (tensor<10x1x6x2xf32>, !tosa.shape<3>) -> tensor<10x6x2xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_2]], %[[VAL_10]] : (tensor<10x10x2x6xf32>, tensor<4xi32>) -> tensor<10x2x10x6xf32> +// CHECK: %[[VAL_12:.*]] = tosa.const_shape {value = dense<[10, 2, 60]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_12]] : (tensor<10x2x10x6xf32>, !tosa.shape<3>) -> tensor<10x2x60xf32> +// CHECK: %[[VAL_14:.*]] = tosa.matmul %[[VAL_9]], %[[VAL_13]] : (tensor<10x6x2xf32>, tensor<10x2x60xf32>) -> tensor<10x6x60xf32> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[10, 6, 10, 6]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_16:.*]] = 
tosa.reshape %[[VAL_14]], %[[VAL_15]] : (tensor<10x6x60xf32>, !tosa.shape<4>) -> tensor<10x6x10x6xf32> +// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_16]], %[[VAL_17]] : (tensor<10x6x10x6xf32>, tensor<4xi32>) -> tensor<10x10x6x6xf32> +// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x10x6x6xf32> -> !torch.vtensor<[10,10,6,6],f32> +// CHECK: return %[[VAL_19]] : !torch.vtensor<[10,10,6,6],f32> func.func @torch.aten.matmul_4d_broadcast(%arg0 : !torch.vtensor<[10,6,2],f32>, %arg1 : !torch.vtensor<[10,10,2,6],f32>) -> !torch.vtensor<[10,10,6,6],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,6,2],f32>, !torch.vtensor<[10,10,2,6],f32> -> !torch.vtensor<[10,10,6,6],f32> return %0 : !torch.vtensor<[10,10,6,6],f32> @@ -185,14 +272,24 @@ func.func @torch.aten.matmul_4d_broadcast(%arg0 : !torch.vtensor<[10,6,2],f32>, // ----- -// CHECK: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<4x1x5x6xf32>) -> tensor<1x20x6xf32> -// CHECK-NEXT: %[[VAL_3:.+]] = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %[[VAL_4:.+]] = tosa.transpose %0, %[[VAL_3]] : (tensor<1x3x6x7xf32>, tensor<4xi32>) -> tensor<6x1x3x7xf32> -// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<6x1x3x7xf32>) -> tensor<1x6x21xf32> -// CHECK-NEXT: %[[VAL_6:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_5]] : (tensor<1x20x6xf32>, tensor<1x6x21xf32>) -> tensor<1x20x21xf32> -// CHECK-NEXT: %[[VAL_7:.+]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x20x21xf32>) -> tensor<4x5x3x7xf32> -// CHECK-NEXT: %[[VAL_8:.+]] = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %[[VAL_9:.+]] = tosa.transpose %[[VAL_7]], %[[VAL_8]] : (tensor<4x5x3x7xf32>, tensor<4xi32>) -> tensor<4x3x5x7xf32> +// CHECK-LABEL: func.func
@torch.aten.matmul_4d_broadcast_2( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,1,5,6],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,3,6,7],f32>) -> !torch.vtensor<[4,3,5,7],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,3,6,7],f32> -> tensor<1x3x6x7xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,1,5,6],f32> -> tensor<4x1x5x6xf32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 20, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor<4x1x5x6xf32>, !tosa.shape<3>) -> tensor<1x20x6xf32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_2]], %[[VAL_6]] : (tensor<1x3x6x7xf32>, tensor<4xi32>) -> tensor<6x1x3x7xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[1, 6, 21]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_8]] : (tensor<6x1x3x7xf32>, !tosa.shape<3>) -> tensor<1x6x21xf32> +// CHECK: %[[VAL_10:.*]] = tosa.matmul %[[VAL_5]], %[[VAL_9]] : (tensor<1x20x6xf32>, tensor<1x6x21xf32>) -> tensor<1x20x21xf32> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<[4, 5, 3, 7]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]], %[[VAL_11]] : (tensor<1x20x21xf32>, !tosa.shape<4>) -> tensor<4x5x3x7xf32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_12]], %[[VAL_13]] : (tensor<4x5x3x7xf32>, tensor<4xi32>) -> tensor<4x3x5x7xf32> +// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<4x3x5x7xf32> -> !torch.vtensor<[4,3,5,7],f32> +// CHECK: return %[[VAL_15]] : !torch.vtensor<[4,3,5,7],f32> func.func @torch.aten.matmul_4d_broadcast_2(%arg0 
: !torch.vtensor<[4,1,5,6],f32>, %arg1 : !torch.vtensor<[1,3,6,7],f32>) -> !torch.vtensor<[4,3,5,7],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,1,5,6],f32>, !torch.vtensor<[1,3,6,7],f32> -> !torch.vtensor<[4,3,5,7],f32> return %0 : !torch.vtensor<[4,3,5,7],f32> @@ -200,13 +297,24 @@ func.func @torch.aten.matmul_4d_broadcast_2(%arg0 : !torch.vtensor<[4,1,5,6],f32 // ----- -// CHECK: %[[VAL_2:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<8x16xf32>) -> tensor<1x8x16xf32> -// CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<100x4x8xf32>) -> tensor<1x400x8xf32> -// CHECK-NEXT: %[[VAL_4:.+]] = "tosa.const"() <{value = dense<[1, 0, 2]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %[[VAL_5:.+]] = tosa.transpose %[[VAL_2]], %[[VAL_4]] : (tensor<1x8x16xf32>, tensor<3xi32>) -> tensor<8x1x16xf32> -// CHECK-NEXT: %[[VAL_6:.+]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<8x1x16xf32>) -> tensor<1x8x16xf32> -// CHECK-NEXT: %[[VAL_7:.+]] = tosa.matmul %[[VAL_3]], %[[VAL_6]] : (tensor<1x400x8xf32>, tensor<1x8x16xf32>) -> tensor<1x400x16xf32> -// CHECK-NEXT: %[[VAL_8:.+]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x400x16xf32>) -> tensor<100x4x16xf32> +// CHECK-LABEL: func.func @torch.aten.matmul_3d_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[100,4,8],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[100,4,16],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[100,4,8],f32> -> tensor<100x4x8xf32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 8, 16]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_4]] : (tensor<8x16xf32>, !tosa.shape<3>) -> tensor<1x8x16xf32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 400, 8]> : 
tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_6]] : (tensor<100x4x8xf32>, !tosa.shape<3>) -> tensor<1x400x8xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<[1, 0, 2]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_9:.*]] = tosa.transpose %[[VAL_5]], %[[VAL_8]] : (tensor<1x8x16xf32>, tensor<3xi32>) -> tensor<8x1x16xf32> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 8, 16]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_10]] : (tensor<8x1x16xf32>, !tosa.shape<3>) -> tensor<1x8x16xf32> +// CHECK: %[[VAL_12:.*]] = tosa.matmul %[[VAL_7]], %[[VAL_11]] : (tensor<1x400x8xf32>, tensor<1x8x16xf32>) -> tensor<1x400x16xf32> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[100, 4, 16]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]], %[[VAL_13]] : (tensor<1x400x16xf32>, !tosa.shape<3>) -> tensor<100x4x16xf32> +// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<100x4x16xf32> -> !torch.vtensor<[100,4,16],f32> +// CHECK: return %[[VAL_15]] : !torch.vtensor<[100,4,16],f32> func.func @torch.aten.matmul_3d_broadcast(%arg0 : !torch.vtensor<[100,4,8],f32>, %arg1 : !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[100,4,16],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[100,4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[100,4,16],f32> return %0 : !torch.vtensor<[100,4,16],f32> @@ -258,14 +366,17 @@ func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e-01 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e-01> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// 
CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_6]] : (tensor, tensor<1x1xf32>) -> tensor -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor -// CHECK: %[[VAL_9:.*]] = tosa.select %[[VAL_7]], %[[VAL_1]], %[[VAL_8]] : (tensor, tensor, tensor) -> tensor -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_8]] : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]], %[[VAL_10]] : (tensor, tensor<1x1xf32>, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_9]], %[[VAL_1]], %[[VAL_11]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_13]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.leaky_relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %fp0 = torch.constant.float 1.000000e-01 @@ -376,11 +487,13 
@@ func.func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor -// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_3]], %[[VAL_7]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_2]], %[[VAL_7]], %[[VAL_8]] : (tensor, tensor<1x1xf32>, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_10:.*]] = tosa.add %[[VAL_3]], %[[VAL_9]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %int1 = torch.constant.int 1 @@ -397,11 +510,13 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> 
tensor -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor -// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_3]], %[[VAL_7]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_2]], %[[VAL_7]], %[[VAL_8]] : (tensor, tensor<1x1xf32>, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_10:.*]] = tosa.sub %[[VAL_3]], %[[VAL_9]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %int1 = torch.constant.int 1 @@ -414,11 +529,12 @@ func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // CHECK-LABEL: func.func @torch.aten.mul$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = 
torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_2]], %[[VAL_4]] : (tensor, tensor, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32> @@ -433,9 +549,10 @@ func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (tensor, tensor, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_7]] : 
!torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32> -> !torch.vtensor<[?, ?],f32> @@ -459,50 +576,54 @@ func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt // ----- // CHECK-LABEL: func.func @test_reduce_mean_dim$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5,6],f32>) -> !torch.vtensor<[4,5,6],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5,6],f32> -> tensor<3x4x5x6xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list // CHECK: %[[VAL_4:.*]] = torch.constant.bool false // CHECK: %[[VAL_5:.*]] = torch.constant.none -// CHECK: %[[VAL_6:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xf32> -// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x?x?x?xf32>) -> tensor -// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<-1.08420217E-19> : tensor}> : () -> tensor -// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> -// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_7]], %[[VAL_9]] {shift = 0 : i8} : (tensor, tensor<1x1x1xf32>) -> tensor -// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor -> !torch.vtensor<[?,?,?],f32> -// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?],f32> -// CHECK: } -func.func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[VAL_6:.*]] = tosa.reduce_sum %[[VAL_1]] {axis 
= 0 : i32} : (tensor<3x4x5x6xf32>) -> tensor<1x4x5x6xf32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<[4, 5, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x4x5x6xf32>, !tosa.shape<3>) -> tensor<4x5x6xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0.333333343> : tensor}> : () -> tensor +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_10]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_8]], %[[VAL_11]], %[[VAL_12]] : (tensor<4x5x6xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<4x5x6xf32> +// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<4x5x6xf32> -> !torch.vtensor<[4,5,6],f32> +// CHECK: return %[[VAL_14]] : !torch.vtensor<[4,5,6],f32> +// CHECK: } +func.func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[3,4,5,6],f32>) -> !torch.vtensor<[4,5,6],f32> { %dim0 = torch.constant.int 0 %reducedims = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list %keepdims = torch.constant.bool false %dtype = torch.constant.none - // expected-error @+1 {{Failed convertReduceMean: support for dynamic input shape not implemented}} - %0 = torch.aten.mean.dim %arg0, %reducedims, %keepdims, %dtype : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?],f32> + %0 = torch.aten.mean.dim %arg0, %reducedims, %keepdims, %dtype : !torch.vtensor<[3,4,5,6],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,5,6],f32> + return %0 : !torch.vtensor<[4,5,6],f32> } // ----- // CHECK-LABEL: func.func @test_reduce_sum_dims$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> 
!torch.vtensor<[?,?,?],f32> { -// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor -// CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none -// CHECK: %[[ARG2_BUILTIN:.*]] = torch.constant.bool false -// CHECK: %[[ARG3:.*]] = torch.constant.int 0 -// CHECK: %[[ARG3_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG3]] : (!torch.int) -> !torch.list -// CHECK: %[[SUM:.*]] = tosa.reduce_sum %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xf32> -// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[SUM]] {new_shape = array} : (tensor<1x?x?x?xf32>) -> tensor -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?,?],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32> -func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5,6],f32>) -> !torch.vtensor<[4,5,6],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5,6],f32> -> tensor<3x4x5x6xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.bool false +// CHECK: %[[VAL_4:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 0 : i32} : (tensor<3x4x5x6xf32>) -> tensor<1x4x5x6xf32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<[4, 5, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x4x5x6xf32>, !tosa.shape<3>) -> tensor<4x5x6xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<4x5x6xf32> -> !torch.vtensor<[4,5,6],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[4,5,6],f32> +// CHECK: } +func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[3,4,5,6],f32>) -> 
!torch.vtensor<[4,5,6],f32> { %none = torch.constant.none %false = torch.constant.bool false %int0 = torch.constant.int 0 %0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list - %1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> - return %1 : !torch.vtensor<[?,?,?],f32> + %1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[3,4,5,6],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,5,6],f32> + return %1 : !torch.vtensor<[4,5,6],f32> } // ----- @@ -516,15 +637,17 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // CHECK: %[[VAL_5:.*]] = torch.constant.none // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> -// CHECK: %[[VAL_9:.*]] = tosa.abs %[[VAL_1]] : (tensor<3x151x64xf32>) -> tensor<3x151x64xf32> -// CHECK: %[[VAL_10:.*]] = tosa.pow %[[VAL_9]], %[[VAL_8]] : (tensor<3x151x64xf32>, tensor<1x1x1xf32>) -> tensor<3x151x64xf32> -// CHECK: %[[VAL_11:.*]] = tosa.reduce_sum %[[VAL_10]] {axis = 2 : i32} : (tensor<3x151x64xf32>) -> tensor<3x151x1xf32> -// CHECK: %[[VAL_12:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor) -> tensor -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> -// CHECK: %[[VAL_14:.*]] = tosa.pow %[[VAL_11]], %[[VAL_13]] : (tensor<3x151x1xf32>, tensor<1x1x1xf32>) -> tensor<3x151x1xf32> -// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<3x151x1xf32> -> !torch.vtensor<[3,151,1],f32> -// CHECK: return %[[VAL_15]] : !torch.vtensor<[3,151,1],f32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: 
%[[VAL_9:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_8]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.abs %[[VAL_1]] : (tensor<3x151x64xf32>) -> tensor<3x151x64xf32> +// CHECK: %[[VAL_11:.*]] = tosa.pow %[[VAL_10]], %[[VAL_9]] : (tensor<3x151x64xf32>, tensor<1x1x1xf32>) -> tensor<3x151x64xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_11]] {axis = 2 : i32} : (tensor<3x151x64xf32>) -> tensor<3x151x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor) -> tensor +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_13]], %[[VAL_14]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.pow %[[VAL_12]], %[[VAL_15]] : (tensor<3x151x1xf32>, tensor<1x1x1xf32>) -> tensor<3x151x1xf32> +// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<3x151x1xf32> -> !torch.vtensor<[3,151,1],f32> +// CHECK: return %[[VAL_17]] : !torch.vtensor<[3,151,1],f32> // CHECK: } func.func @test_linalg_vector_norm$basic(%arg0: !torch.vtensor<[3,151,64],f32>) -> (!torch.vtensor<[3,151,1],f32>) { %float2.000000e00 = torch.constant.float 2.000000e+00 @@ -539,16 +662,18 @@ func.func @test_linalg_vector_norm$basic(%arg0: !torch.vtensor<[3,151,64],f32>) // ----- // CHECK-LABEL: func.func @test_reduce_sum$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> { -// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor -// CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none -// CHECK: %[[REDUCE1:.*]] = tosa.reduce_sum %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xf32> -// CHECK: %[[REDUCE2:.*]] = tosa.reduce_sum %[[REDUCE1]] {axis = 1 : i32} : (tensor<1x?x?x?xf32>) -> tensor<1x1x?x?xf32> -// CHECK: %[[REDUCE3:.*]] = tosa.reduce_sum %[[REDUCE2]] {axis = 2 : i32} : (tensor<1x1x?x?xf32>) -> 
tensor<1x1x1x?xf32> -// CHECK: %[[REDUCE4:.*]] = tosa.reduce_sum %[[REDUCE3]] {axis = 3 : i32} : (tensor<1x1x1x?xf32>) -> tensor<1x1x1x1xf32> -// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[REDUCE4]] {new_shape = array} : (tensor<1x1x1x1xf32>) -> tensor<1xf32> -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xf32> -> !torch.vtensor<[1],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 1 : i32} : (tensor<1x?x?x?xf32>) -> tensor<1x1x?x?xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reduce_sum %[[VAL_4]] {axis = 2 : i32} : (tensor<1x1x?x?xf32>) -> tensor<1x1x1x?xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reduce_sum %[[VAL_5]] {axis = 3 : i32} : (tensor<1x1x1x?xf32>) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x1x1xf32>, !tosa.shape<1>) -> tensor<1xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[1],f32> +// CHECK: } func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> { %none = torch.constant.none %0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?,?,?,?],f32>, !torch.none -> !torch.vtensor<[1],f32> @@ -558,15 +683,17 @@ func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // ----- // CHECK-LABEL: func.func @test_reduce_all$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) 
-> !torch.vtensor<[1],i1> { -// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor -// CHECK: %[[REDUCE1:.*]] = tosa.reduce_all %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xi1> -// CHECK: %[[REDUCE2:.*]] = tosa.reduce_all %[[REDUCE1]] {axis = 1 : i32} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1> -// CHECK: %[[REDUCE3:.*]] = tosa.reduce_all %[[REDUCE2]] {axis = 2 : i32} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1> -// CHECK: %[[REDUCE4:.*]] = tosa.reduce_all %[[REDUCE3]] {axis = 3 : i32} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1> -// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[REDUCE4]] {new_shape = array} : (tensor<1x1x1x1xi1>) -> tensor<1xi1> -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1> -// CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.reduce_all %[[VAL_1]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xi1> +// CHECK: %[[VAL_3:.*]] = tosa.reduce_all %[[VAL_2]] {axis = 1 : i32} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1> +// CHECK: %[[VAL_4:.*]] = tosa.reduce_all %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1> +// CHECK: %[[VAL_5:.*]] = tosa.reduce_all %[[VAL_4]] {axis = 3 : i32} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor<1x1x1x1xi1>, !tosa.shape<1>) -> tensor<1xi1> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1xi1> -> !torch.vtensor<[1],i1> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[1],i1> +// CHECK: } func.func @test_reduce_all$basic(%arg0: 
!torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { %0 = torch.aten.all %arg0 : !torch.vtensor<[?,?,?,?],i1> -> !torch.vtensor<[1],i1> return %0 : !torch.vtensor<[1],i1> @@ -575,33 +702,37 @@ func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch. // ----- // CHECK-LABEL: func.func @test_reduce_any_dim$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> { -// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor -// CHECK: %[[ARG1:.*]] = torch.constant.int 0 -// CHECK: %[[ARG2:.*]] = torch.constant.bool false -// CHECK: %[[REDUCE:.*]] = tosa.reduce_any %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xi1> -// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[REDUCE]] {new_shape = array} : (tensor<1x?x?x?xi1>) -> tensor -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?,?],i1> -// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1> -func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5,6],i1>) -> !torch.vtensor<[4,5,6],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5,6],i1> -> tensor<3x4x5x6xi1> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.bool false +// CHECK: %[[VAL_4:.*]] = tosa.reduce_any %[[VAL_1]] {axis = 0 : i32} : (tensor<3x4x5x6xi1>) -> tensor<1x4x5x6xi1> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[4, 5, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_5]] : (tensor<1x4x5x6xi1>, !tosa.shape<3>) -> tensor<4x5x6xi1> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x5x6xi1> -> !torch.vtensor<[4,5,6],i1> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,5,6],i1> +// CHECK: } 
+func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[3,4,5,6],i1>) -> !torch.vtensor<[4,5,6],i1> { %int0 = torch.constant.int 0 %false = torch.constant.bool false - %0 = torch.aten.any.dim %arg0, %int0, %false : !torch.vtensor<[?,?,?,?],i1>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],i1> - return %0 : !torch.vtensor<[?,?,?],i1> + %0 = torch.aten.any.dim %arg0, %int0, %false : !torch.vtensor<[3,4,5,6],i1>, !torch.int, !torch.bool -> !torch.vtensor<[4,5,6],i1> + return %0 : !torch.vtensor<[4,5,6],i1> } // ----- // CHECK-LABEL: func.func @test_reduce_any$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { -// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor -// CHECK: %[[REDUCE1:.*]] = tosa.reduce_any %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xi1> -// CHECK: %[[REDUCE2:.*]] = tosa.reduce_any %[[REDUCE1]] {axis = 1 : i32} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1> -// CHECK: %[[REDUCE3:.*]] = tosa.reduce_any %[[REDUCE2]] {axis = 2 : i32} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1> -// CHECK: %[[REDUCE4:.*]] = tosa.reduce_any %[[REDUCE3]] {axis = 3 : i32} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1> -// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[REDUCE4]] {new_shape = array} : (tensor<1x1x1x1xi1>) -> tensor<1xi1> -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1> -// CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.reduce_any %[[VAL_1]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xi1> +// CHECK: %[[VAL_3:.*]] = tosa.reduce_any %[[VAL_2]] {axis = 1 : i32} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1> +// CHECK: %[[VAL_4:.*]] = 
tosa.reduce_any %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1> +// CHECK: %[[VAL_5:.*]] = tosa.reduce_any %[[VAL_4]] {axis = 3 : i32} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor<1x1x1x1xi1>, !tosa.shape<1>) -> tensor<1xi1> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1xi1> -> !torch.vtensor<[1],i1> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[1],i1> +// CHECK: } func.func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { %0 = torch.aten.any %arg0 : !torch.vtensor<[?,?,?,?],i1> -> !torch.vtensor<[1],i1> return %0 : !torch.vtensor<[1],i1> @@ -660,10 +791,11 @@ func.func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !to // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.pow %[[VAL_1]], %[[VAL_4]] : (tensor, tensor<1x1xf32>) -> tensor -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.pow %[[VAL_1]], %[[VAL_5]] : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_7]] : 
!torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %fp0 = torch.constant.float 3.123400e+00 @@ -680,12 +812,15 @@ func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) // CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<6.432100e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_7]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor -// CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_6]], %[[VAL_8]] : (tensor<1x1xf32>, tensor) -> tensor -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_6]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_8]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_1]], %[[VAL_9]], %[[VAL_10]] : (tensor, tensor<1x1xf32>, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_12:.*]] = tosa.sub %[[VAL_7]], %[[VAL_11]] : (tensor<1x1xf32>, tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return 
%[[VAL_13]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %other = torch.constant.float 3.123400e+00 @@ -703,12 +838,15 @@ func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_7]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor -// CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_6]], %[[VAL_8]] : (tensor<1x1xf32>, tensor) -> tensor -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_6]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_8]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_1]], %[[VAL_9]], %[[VAL_10]] : (tensor, tensor<1x1xf32>, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_12:.*]] = tosa.sub %[[VAL_7]], %[[VAL_11]] : (tensor<1x1xf32>, tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return 
%[[VAL_13]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.rsub.Scalar$float_int(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %other = torch.constant.float 3.123400e+00 @@ -768,13 +906,14 @@ func.func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // ----- // CHECK-LABEL: func.func @torch.aten.reshape$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int -1 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_4]] : (tensor, !tosa.shape<1>) -> tensor +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[?],f32> // CHECK: } func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> { %dim0 = torch.constant.int -1 @@ -794,24 +933,35 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to // CHECK: %[[VAL_5:.*]] = torch.constant.float 1.000000e-05 // CHECK: %[[VAL_6:.*]] = torch.constant.bool true // CHECK: %[[VAL_7:.*]] = torch.constant.bool false -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : 
(tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<4x1xf32>) -> tensor<1x4x1xf32> -// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<4x1xf32>) -> tensor<1x4x1xf32> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<4x1xf32>) -> tensor<1x4x1xf32> -// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<4x1xf32>) -> tensor<1x4x1xf32> -// CHECK: %[[VAL_18:.*]] = tosa.sub %[[VAL_1]], %[[VAL_13]] : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_19:.*]] = tosa.add %[[VAL_14]], %[[VAL_15]] : (tensor<1x4x1xf32>, tensor<1x1x1xf32>) -> tensor<1x4x1xf32> -// CHECK: %[[VAL_20:.*]] = tosa.rsqrt %[[VAL_19]] : (tensor<1x4x1xf32>) -> tensor<1x4x1xf32> -// CHECK: %[[VAL_21:.*]] = tosa.mul %[[VAL_18]], %[[VAL_20]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_21]], %[[VAL_16]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_23:.*]] = tosa.add %[[VAL_22]], %[[VAL_17]] : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32> -// CHECK: return %[[VAL_24]] : !torch.vtensor<[10,4,3],f32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[4, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_9:.*]] = tosa.reshape 
%[[VAL_2]], %[[VAL_8]] : (tensor<4xf32>, !tosa.shape<2>) -> tensor<4x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[4, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_10]] : (tensor<4xf32>, !tosa.shape<2>) -> tensor<4x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.const_shape {value = dense<[4, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_12]] : (tensor<4xf32>, !tosa.shape<2>) -> tensor<4x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[4, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_14]] : (tensor<4xf32>, !tosa.shape<2>) -> tensor<4x1xf32> +// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_17]] : (tensor<4x1xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_19:.*]] = tosa.const_shape {value = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_19]] : (tensor<4x1xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_16]], %[[VAL_21]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_23:.*]] = tosa.const_shape {value = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_13]], %[[VAL_23]] : (tensor<4x1xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_25:.*]] = tosa.const_shape {value = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_15]], %[[VAL_25]] : (tensor<4x1xf32>, 
!tosa.shape<3>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_27:.*]] = tosa.sub %[[VAL_1]], %[[VAL_18]] : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_28:.*]] = tosa.add %[[VAL_20]], %[[VAL_22]] : (tensor<1x4x1xf32>, tensor<1x1x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_29:.*]] = tosa.rsqrt %[[VAL_28]] : (tensor<1x4x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_30:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_31:.*]] = tosa.mul %[[VAL_27]], %[[VAL_29]], %[[VAL_30]] : (tensor<10x4x3xf32>, tensor<1x4x1xf32>, tensor<1xi8>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_32:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_33:.*]] = tosa.mul %[[VAL_31]], %[[VAL_24]], %[[VAL_32]] : (tensor<10x4x3xf32>, tensor<1x4x1xf32>, tensor<1xi8>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_34:.*]] = tosa.add %[[VAL_33]], %[[VAL_26]] : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_35:.*]] = torch_c.from_builtin_tensor %[[VAL_34]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32> +// CHECK: return %[[VAL_35]] : !torch.vtensor<[10,4,3],f32> // CHECK: } func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> { %0 = torch.vtensor.literal(dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>) : !torch.vtensor<[4],f32> @@ -826,17 +976,18 @@ func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32 // ----- -// CHECK-LABEL: func.func @forward( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> { +// CHECK-LABEL: func.func @torch.aten.flatten.using_ints$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> 
tensor<10x3x8x9x3x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 4 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32> -// CHECK: %[[VAL_5:.*]] = tensor.cast %[[VAL_4]] : tensor<10x3x216x4xf32> to tensor<10x3x?x4xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[10, 3, 216, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_4]] : (tensor<10x3x8x9x3x4xf32>, !tosa.shape<4>) -> tensor<10x3x216x4xf32> +// CHECK: %[[VAL_6:.*]] = tensor.cast %[[VAL_5]] : tensor<10x3x216x4xf32> to tensor<10x3x?x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[10,3,?,4],f32> // CHECK: } -func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,3,?,4],f32> { +func.func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,3,?,4],f32> { %int4 = torch.constant.int 4 %int2 = torch.constant.int 2 %0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[10,3,8,9,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,3,?,4],f32> @@ -845,18 +996,19 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor // ----- -// CHECK-LABEL: func.func @forward( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,6,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> { -// CHECK: %[[VAL:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,6,4],f32> -> tensor<1x6x4xf32> -// CHECK: %[[VAL_1:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_3:.*]] = 
torch.constant.int 3 -// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL]] {new_shape = array} : (tensor<1x6x4xf32>) -> tensor<1x2x3x4xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,3,4],f32> +// CHECK-LABEL: func.func @torch.aten.unflatten.int$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,6,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,6,4],f32> -> tensor<1x6x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_6]] : (tensor<1x6x4xf32>, !tosa.shape<4>) -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,2,3,4],f32> // CHECK: } -func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3,4],f32> { +func.func @torch.aten.unflatten.int$basic(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3,4],f32> { %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -880,31 +1032,42 @@ func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_8]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = 
dense<1.200000e+01> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_11:.*]] = tosa.reciprocal %[[VAL_10]] : (tensor<1xf32>) -> tensor<1xf32> -// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_5]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> -// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> -// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_18:.*]] = tosa.sub %[[VAL_5]], %[[VAL_17]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_19:.*]] = tosa.mul %[[VAL_18]], %[[VAL_18]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_20:.*]] = tosa.reduce_sum %[[VAL_19]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> -// CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> -// CHECK: %[[VAL_22:.*]] = tosa.reduce_sum %[[VAL_21]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_23:.*]] = tosa.reshape %[[VAL_22]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_24:.*]] = tosa.mul %[[VAL_23]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> -// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : 
(tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> -// CHECK: %[[VAL_27:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor -// CHECK: %[[VAL_28:.*]] = tosa.reshape %[[VAL_27]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_29:.*]] = tosa.sub %[[VAL_5]], %[[VAL_17]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_30:.*]] = tosa.add %[[VAL_24]], %[[VAL_28]] : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_31:.*]] = tosa.rsqrt %[[VAL_30]] : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_32:.*]] = tosa.mul %[[VAL_29]], %[[VAL_31]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_33:.*]] = tosa.mul %[[VAL_32]], %[[VAL_25]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_34:.*]] = tosa.add %[[VAL_33]], %[[VAL_26]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_35:.*]] = torch_c.from_builtin_tensor %[[VAL_34]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32> -// CHECK: return %[[VAL_35]] : !torch.vtensor<[5,2,2,3],f32> +// CHECK: %[[VAL_12:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_12]] : (tensor<1xf32>, !tosa.shape<4>) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_5]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[5, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_16]], %[[VAL_17]] : 
(tensor<5x1x1x1xf32>, !tosa.shape<4>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_20:.*]] = tosa.mul %[[VAL_18]], %[[VAL_13]], %[[VAL_19]] : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.sub %[[VAL_5]], %[[VAL_20]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_21]], %[[VAL_21]], %[[VAL_22]] : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>, tensor<1xi8>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_24:.*]] = tosa.reduce_sum %[[VAL_23]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> +// CHECK: %[[VAL_25:.*]] = tosa.reduce_sum %[[VAL_24]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> +// CHECK: %[[VAL_26:.*]] = tosa.reduce_sum %[[VAL_25]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_27:.*]] = tosa.const_shape {value = dense<[5, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_28:.*]] = tosa.reshape %[[VAL_26]], %[[VAL_27]] : (tensor<5x1x1x1xf32>, !tosa.shape<4>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_29:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_30:.*]] = tosa.mul %[[VAL_28]], %[[VAL_13]], %[[VAL_29]] : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_31:.*]] = tosa.const_shape {value = dense<[1, 2, 2, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_32:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_31]] : (tensor<2x2x3xf32>, !tosa.shape<4>) -> tensor<1x2x2x3xf32> +// CHECK: %[[VAL_33:.*]] = tosa.const_shape {value = dense<[1, 2, 2, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_34:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_33]] : 
(tensor<2x2x3xf32>, !tosa.shape<4>) -> tensor<1x2x2x3xf32> +// CHECK: %[[VAL_35:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor +// CHECK: %[[VAL_36:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_37:.*]] = tosa.reshape %[[VAL_35]], %[[VAL_36]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_38:.*]] = tosa.sub %[[VAL_5]], %[[VAL_20]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_39:.*]] = tosa.add %[[VAL_30]], %[[VAL_37]] : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_40:.*]] = tosa.rsqrt %[[VAL_39]] : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_41:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_42:.*]] = tosa.mul %[[VAL_38]], %[[VAL_40]], %[[VAL_41]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>, tensor<1xi8>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_43:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_44:.*]] = tosa.mul %[[VAL_42]], %[[VAL_32]], %[[VAL_43]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>, tensor<1xi8>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_45:.*]] = tosa.add %[[VAL_44]], %[[VAL_34]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_46:.*]] = torch_c.from_builtin_tensor %[[VAL_45]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32> +// CHECK: return %[[VAL_46]] : !torch.vtensor<[5,2,2,3],f32> // CHECK: } func.func @torch.aten.native_layer_norm$basic(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> { %float5.000000e-01 = torch.constant.float 5.000000e-01 @@ -950,7 +1113,7 @@ func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: ! 
// ----- -// CHECK-LABEL: func.func @forward( +// CHECK-LABEL: func.func @torch.aten.permute$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 @@ -962,7 +1125,7 @@ func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: ! // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32> // CHECK: } -func.func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32> { +func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32> { %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 %int0 = torch.constant.int 0 @@ -995,9 +1158,10 @@ func.func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.693147182> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> // CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<1x1xf32>) -> tensor<1x1xf32> // CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_1]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_4]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_3]], %[[VAL_5]] : (tensor, tensor<1x1xf32>, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func 
@torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> { %0 = torch.aten.log2 %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> @@ -1031,11 +1195,11 @@ func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<4x3xi32>) -> tensor<4x3x1xi32> -// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32> -// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32> +// CHECK: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[4, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_3]] : (tensor<4x3xi32>, !tosa.shape<3>) -> tensor<4x3x1xi32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,3,1],si32> // CHECK: } - func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,3,1],si32> { %int2 = torch.constant.int 2 %0 = torch.aten.unsqueeze %arg0, %int2 : !torch.vtensor<[4,3],si32>, !torch.int -> !torch.vtensor<[4,3,1],si32> @@ -1045,12 +1209,13 @@ func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !to // ----- // CHECK-LABEL: func.func @torch.aten.unsqueeze$negative_dim( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32> // CHECK: %[[VAL_2:.*]] = torch.constant.int -1 
-// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<4x3xi32>) -> tensor<4x3x1xi32> -// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32> -// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32> +// CHECK: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[4, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_3]] : (tensor<4x3xi32>, !tosa.shape<3>) -> tensor<4x3x1xi32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,3,1],si32> // CHECK: } func.func @torch.aten.unsqueeze$negative_dim(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,3,1],si32> { %int2 = torch.constant.int -1 @@ -1095,13 +1260,10 @@ func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> { // ----- // CHECK-LABEL: func.func @torch.aten.dropout$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_2:.*]] = torch.constant.float 0.000000e+00 -// CHECK: %[[VAL_3:.*]] = torch.constant.bool false -// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_1]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch.constant.float 0.000000e+00 +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: return %[[VAL_0]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> { %float0.000000e00 = torch.constant.float 0.000000e+00 @@ -1147,18 +1309,21 @@ 
func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) // ----- -// CHECK-LABEL: @torch.aten.max.dim$basic( -// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x3xf32>) -// CHECK-DAG: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> -// CHECK-DAG: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> -// CHECK-DAG: %[[VAL_TRUE:.*]] = torch.constant.bool true -// CHECK-DAG: %[[VAL_I2:.*]] = torch.constant.int 2 -// CHECK-DAG: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3x2xi64>) -> tensor<3x2x1xi64> -// CHECK-DAG: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> -// CHECK-DAG: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> -// CHECK: return %[[VAL_6]] : tensor<3x2x1xf32> +// CHECK-LABEL: func.func @torch.aten.max.dim$basic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> { +// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> +// CHECK: %[[VAL_3:.*]] = torch.constant.bool true +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[3, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_6:.*]] = tosa.reduce_max %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> +// CHECK: %[[VAL_8:.*]] = 
tosa.argmax %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[3, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_9]] : (tensor<3x2xi64>, !tosa.shape<3>) -> tensor<3x2x1xi64> +// CHECK: %[[VAL_11:.*]] = torch_c.to_builtin_tensor %[[VAL_7]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> +// CHECK: return %[[VAL_11]] : tensor<3x2x1xf32> +// CHECK: } func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> { %0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> %true = torch.constant.bool true @@ -1181,14 +1346,14 @@ func.func @torch.vtensor.literal_si64$basic() -> !torch.vtensor<[1,512],si64> { // ----- // CHECK-LABEL: func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[CST0:.*]] = torch.constant.int 0 -// CHECK: %[[CST5:.*]] = torch.constant.int 5 -// CHECK: %[[CST1:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>}> : () -> tensor<5xi64> -// CHECK: %[[VAL_1:.*]] = tosa.cast %[[VAL_0]] : (tensor<5xi64>) -> tensor<5xi64> -// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %1 : tensor<5xi64> -> !torch.vtensor<[5],si64> -// CHECK: return %[[VAL_2]] : !torch.vtensor<[5],si64> +// CHECK: %[[VAL_0:.*]] = torch.constant.none +// CHECK: %[[VAL_1:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_2:.*]] = torch.constant.int 5 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>}> : () -> tensor<5xi64> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<5xi64> -> !torch.vtensor<[5],si64> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[5],si64> +// CHECK: } func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> 
{ %none = torch.constant.none %int0 = torch.constant.int 0 @@ -1212,23 +1377,25 @@ func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> { // ----- // CHECK-LABEL: func.func @torch.aten.copy( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> { -// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,1,5,5],ui8> -> tensor<1x1x5x5xi8> -// CHECK: %[[CST5:.*]] = torch.constant.int 5 -// CHECK: %[[CST1:.*]] = torch.constant.int 1 -// CHECK: %[[CST11:.*]] = torch.constant.int 11 -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[CST0:.*]] = torch.constant.int 0 -// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: %[[VAL_2:.*]] = tosa.equal %[[VAL_0]], %[[VAL_1]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.logical_not %[[VAL_2]] : (tensor) -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x5x5xi8>}> : () -> tensor<1x1x5x5xi8> -// CHECK: %[[VAL_5:.*]] = tosa.equal %[[INP]], %[[VAL_4]] : (tensor<1x1x5x5xi8>, tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1> -// CHECK: %[[VAL_6:.*]] = tosa.logical_not %[[VAL_5]] : (tensor<1x1x5x5xi1>) -> tensor<1x1x5x5xi1> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,1,5,5],i1> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,5,5],ui8> -> tensor<1x1x5x5xi8> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 5 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 11 +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = 
torch.constant.bool false +// CHECK: %[[VAL_7:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor) -> tensor +// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]], %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_11]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xi1> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[1, 1, 5, 5]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_14:.*]] = tosa.tile %[[VAL_12]], %[[VAL_13]] : (tensor<1x1x1x1xi1>, !tosa.shape<4>) -> tensor<1x1x5x5xi1> +// CHECK: %[[VAL_15:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1> +// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1> +// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,1,5,5],i1> +// CHECK: } func.func @torch.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> { %int5 = torch.constant.int 5 %int1 = torch.constant.int 1 @@ -1245,18 +1412,17 @@ func.func @torch.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtens } // ----- -// CHECK-LABEL: func.func @torch.aten.to.dtype( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> { -// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[3,5],si64> -> tensor<3x5xi64> -// CHECK: %[[CST11:.*]] = torch.constant.int 11 -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x5xi64>}> : () -> tensor<3x5xi64> -// CHECK: %[[VAL_1:.*]] = tosa.equal %[[INP]], %[[VAL_0]] : (tensor<3x5xi64>, 
tensor<3x5xi64>) -> tensor<3x5xi1> -// CHECK: %[[VAL_2:.*]] = tosa.logical_not %[[VAL_1]] : (tensor<3x5xi1>) -> tensor<3x5xi1> -// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1> -// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,5],i1> -func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> { +// CHECK-LABEL: func.func @torch.aten.to.dtype$toBool( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],si64> -> tensor<3x5xi64> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 11 +// CHECK: %[[VAL_3:.*]] = torch.constant.none +// CHECK: %[[VAL_4:.*]] = torch.constant.bool false +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x5xi64>) -> tensor<3x5xi1> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,5],i1> +// CHECK: } +func.func @torch.aten.to.dtype$toBool(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> { %int11 = torch.constant.int 11 %none = torch.constant.none %false = torch.constant.bool false @@ -1265,7 +1431,7 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten } // ----- -// CHECK-LABEL: func.func @torch.aten.to.dtype( +// CHECK-LABEL: func.func @torch.aten.to.dtype$fromBool( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,128],i1> -> tensor<1x128xi1> // CHECK: %[[VAL_2:.*]] = torch.constant.int 4 @@ -1275,7 +1441,7 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x128xi64> -> !torch.vtensor<[1,128],si64> // CHECK: return %[[VAL_6]] : 
!torch.vtensor<[1,128],si64> // CHECK: } -func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> { +func.func @torch.aten.to.dtype$fromBool(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> { %int4 = torch.constant.int 4 %none = torch.constant.none %false = torch.constant.bool false @@ -1293,12 +1459,13 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vten // CHECK: %[[VAL_5:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x5xf32>) -> tensor<3x5xf32> // CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x5xf32>) -> tensor<3x5xf32> // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_9:.*]] = tosa.greater %[[VAL_8]], %[[VAL_1]] : (tensor<1x1xf32>, tensor<3x5xf32>) -> tensor<3x5xi1> -// CHECK: %[[VAL_10:.*]] = tosa.select %[[VAL_9]], %[[VAL_6]], %[[VAL_5]] : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_11:.*]] = tosa.cast %[[VAL_10]] : (tensor<3x5xf32>) -> tensor<3x5xi64> -// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<3x5xi64> -> !torch.vtensor<[3,5],si64> -// CHECK: return %[[VAL_12]] : !torch.vtensor<[3,5],si64> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_8]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.greater %[[VAL_9]], %[[VAL_1]] : (tensor<1x1xf32>, tensor<3x5xf32>) -> tensor<3x5xi1> +// CHECK: %[[VAL_11:.*]] = tosa.select %[[VAL_10]], %[[VAL_6]], %[[VAL_5]] : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_12:.*]] = tosa.cast %[[VAL_11]] : (tensor<3x5xf32>) -> tensor<3x5xi64> +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : 
tensor<3x5xi64> -> !torch.vtensor<[3,5],si64> +// CHECK: return %[[VAL_13]] : !torch.vtensor<[3,5],si64> // CHECK: } func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> { %int4 = torch.constant.int 4 @@ -1317,21 +1484,28 @@ func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> ! // CHECK: %[[VAL_4:.*]] = torch.constant.int -1 // CHECK: %[[VAL_5:.*]] = torch.constant.bool false // CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32> -// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32> -// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> -// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> -// CHECK: %[[VAL_10:.*]] = tosa.concat %[[VAL_8]], %[[VAL_9]], %[[VAL_7]] {axis = 3 : i32} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32> -// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> -// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32> -// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> -// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] {shift = 0 : i8} : (tensor<8x3xi32>, tensor<1x3xi32>) -> tensor<8x3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<8x1xi32>) -> tensor<1x8xi32> -// CHECK: %[[VAL_18:.*]] = 
tosa.gather %[[VAL_11]], %[[VAL_17]] : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> -// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> -// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> -// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,4,2],f32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<[1, 4, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x4x2xi32>, !tosa.shape<4>) -> tensor<1x4x2x1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32> +// CHECK: %[[VAL_12:.*]] = tosa.const_shape {value = dense<[1, 12, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_12]] : (tensor<1x4x3xf32>, !tosa.shape<3>) -> tensor<1x12x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[8, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_14]] : (tensor<1x4x2x3xi32>, !tosa.shape<2>) -> tensor<8x3xi32> +// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_16]], %[[VAL_17]] : (tensor<3xi32>, !tosa.shape<2>) -> tensor<1x3xi32> +// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<0> : 
tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_20:.*]] = tosa.mul %[[VAL_15]], %[[VAL_18]], %[[VAL_19]] : (tensor<8x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<8x3xi32> +// CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32> +// CHECK: %[[VAL_22:.*]] = tosa.const_shape {value = dense<[1, 8]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_23:.*]] = tosa.reshape %[[VAL_21]], %[[VAL_22]] : (tensor<8x1xi32>, !tosa.shape<2>) -> tensor<1x8xi32> +// CHECK: %[[VAL_24:.*]] = tosa.gather %[[VAL_13]], %[[VAL_23]] : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> +// CHECK: %[[VAL_25:.*]] = tosa.const_shape {value = dense<[1, 4, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_24]], %[[VAL_25]] : (tensor<1x8x1xf32>, !tosa.shape<3>) -> tensor<1x4x2xf32> +// CHECK: %[[VAL_27:.*]] = torch_c.from_builtin_tensor %[[VAL_26]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> +// CHECK: return %[[VAL_27]] : !torch.vtensor<[1,4,2],f32> // CHECK: } func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> { %int-1 = torch.constant.int -1 @@ -1341,21 +1515,23 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v } // ----- -// CHECK-LABEL: func.func @torch.aten.add$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2],si32>) -> !torch.vtensor<[2,2],si64> { +// CHECK-LABEL: func.func @torch.aten.add$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2],si32>) -> !torch.vtensor<[2,2],si64> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> // CHECK: 
%[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xi32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2x2xi32>, tensor<1x1xi32>) -> tensor<2x2xi32> -// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_3]], %[[VAL_7]] : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> -// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor<2x2xi32>) -> tensor<2x2xi64> -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[2,2],si64> -// CHECK: } -func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torch.vtensor<[2, 2],si32>) -> !torch.vtensor<[2, 2],si64> { +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor, !tosa.shape<2>) -> tensor<1x1xi32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_2]], %[[VAL_7]], %[[VAL_8]] : (tensor<2x2xi32>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<2x2xi32> +// CHECK: %[[VAL_10:.*]] = tosa.add %[[VAL_3]], %[[VAL_9]] : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> +// CHECK: %[[VAL_11:.*]] = tosa.cast %[[VAL_10]] : (tensor<2x2xi32>) -> tensor<2x2xi64> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64> +// CHECK: return %[[VAL_12]] : !torch.vtensor<[2,2],si64> +// CHECK: } +func.func @torch.aten.add$int(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torch.vtensor<[2, 2],si32>) -> !torch.vtensor<[2, 2],si64> { %int1 = torch.constant.int 1 %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[2, 2],si32>, !torch.vtensor<[2, 
2],si32>, !torch.int -> !torch.vtensor<[2, 2],si64> return %0 : !torch.vtensor<[2, 2],si64> @@ -1368,16 +1544,19 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torc // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_3:.*]] = torch.constant.int 256 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : tensor}> : () -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xi64> -// CHECK: %[[VAL_5_cast:.*]] = tosa.cast %[[VAL_5]] : (tensor<1x1x1x1xi64>) -> tensor<1x1x1x1xi32> -// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xi32> -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5_cast]], %[[VAL_7]] {shift = 0 : i8} : (tensor<1x1x1x1xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x1x1xi32> -// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32> -// CHECK: %[[VAL_10:.*]] = tosa.add %[[VAL_9]], %[[VAL_8]] : (tensor<1x1x128x128xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x128x128xi32> -// CHECK: %[[VAL_11:.*]] = tosa.cast %[[VAL_10]] : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64> -// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> -// CHECK: return %[[VAL_12]] : !torch.vtensor<[1,1,128,128],si64> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_5]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xi64> +// CHECK: %[[VAL_6_cast:.*]] = tosa.cast %[[VAL_6]] : (tensor<1x1x1x1xi64>) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: 
%[[VAL_9:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_8]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_6_cast]], %[[VAL_9]], %[[VAL_10]] : (tensor<1x1x1x1xi32>, tensor<1x1x1x1xi32>, tensor<1xi8>) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32> +// CHECK: %[[VAL_13:.*]] = tosa.add %[[VAL_12]], %[[VAL_11]] : (tensor<1x1x128x128xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x128x128xi32> +// CHECK: %[[VAL_14:.*]] = tosa.cast %[[VAL_13]] : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> +// CHECK: return %[[VAL_15]] : !torch.vtensor<[1,1,128,128],si64> // CHECK: } func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { %int1 = torch.constant.int 1 @@ -1388,17 +1567,19 @@ func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> // ----- // CHECK-LABEL: func.func @torch.aten.slice.negative_start( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = torch.constant.int 100 // CHECK: %[[VAL_5:.*]] = torch.constant.int -16 // CHECK: %[[VAL_1r:.*]] = tosa.reshape -// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1r]] {size = array, start = array} : (tensor<4x65x1x256xf32>) -> tensor<4x16x1x256xf32> -// CHECK: %[[VAL_4r:.*]] = tosa.reshape -// CHECK: %[[VAL_5:.*]] = 
torch_c.from_builtin_tensor %[[VAL_4r]] : tensor<4x16x256xf32> -> !torch.vtensor<[4,16,256],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,16,256],f32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[0, 49, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<[4, 16, 1, 256]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_1r]], %[[VAL_6]], %[[VAL_7]] : (tensor<4x65x1x256xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<4x16x1x256xf32> +// CHECK: %[[VAL_8r:.*]] = tosa.reshape +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8r]] : tensor<4x16x256xf32> -> !torch.vtensor<[4,16,256],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[4,16,256],f32> // CHECK: } func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { %int0 = torch.constant.int 0 @@ -1486,10 +1667,11 @@ func.func @torch.aten.clamp.float(%arg0: !torch.vtensor<[1,1,128,128],f32>) -> ! 
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_2]], %[[VAL_7]], %[[VAL_3]] : (tensor<1x1x128x128xi1>, tensor<1x1x1x1xf32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.select %[[VAL_2]], %[[VAL_8]], %[[VAL_3]] : (tensor<1x1x128x128xi1>, tensor<1x1x1x1xf32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } func.func @torch.aten.masked_fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1,1,128,128],i1>) -> !torch.vtensor<[1,12,128,128],f32> { %int0 = torch.constant.int 0 @@ -1505,10 +1687,11 @@ func.func @torch.aten.masked_fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f3 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1> // CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_3]] {new_shape = 
array} : (tensor) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_4]], %[[VAL_6]], %[[VAL_5]] : (tensor<1x1x128x128xi1>, tensor<1x1x1x1xf32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_6]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_4]], %[[VAL_7]], %[[VAL_5]] : (tensor<1x1x128x128xi1>, tensor<1x1x1x1xf32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } func.func @torch.aten.masked_fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1,1,128,128],i1>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,128,128],f32> { %0 = torch.aten.masked_fill.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> @@ -1536,10 +1719,11 @@ func.func @torch.aten.abs(%arg0: !torch.vtensor<[15,15],si64>) -> !torch.vtensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32> // CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1> -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_5]], 
%[[VAL_4]], %[[VAL_6]] : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x5x5xf32> -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,5,5],f32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_6]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_5]], %[[VAL_4]], %[[VAL_7]] : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x5x5xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,12,5,5],f32> // CHECK: } func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !torch.vtensor<[1,12,5,5],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> { %0 = torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32> @@ -1552,14 +1736,17 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<1x1xf32>) -> tensor<1x1xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<1x1xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.floor %[[VAL_6]] : (tensor<2x4xf32>) -> tensor<2x4xf32> -// 
CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_4]], %[[VAL_7]] {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[2,4],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_6]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x4xf32> +// CHECK: %[[VAL_9:.*]] = tosa.floor %[[VAL_8]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_5]], %[[VAL_9]], %[[VAL_10]] : (tensor<1x1xf32>, tensor<2x4xf32>, tensor<1xi8>) -> tensor<2x4xf32> +// CHECK: %[[VAL_12:.*]] = tosa.sub %[[VAL_1]], %[[VAL_11]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_13]] : !torch.vtensor<[2,4],f32> // CHECK: } func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> { %int2 = torch.constant.int 2 @@ -1579,16 +1766,19 @@ func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !to // CHECK: %[[VAL_6:.*]] = torch.constant.bool false // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor // CHECK: 
%[[VAL_8:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor}> : () -> tensor -// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_11:.*]] = tosa.sub %[[VAL_3]], %[[VAL_2]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_12:.*]] = tosa.abs %[[VAL_11]] : (tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_13:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_9]], %[[VAL_13]] {shift = 0 : i8} : (tensor<1x1xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_10]], %[[VAL_14]] : (tensor<1x1xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_16:.*]] = tosa.greater_equal %[[VAL_15]], %[[VAL_12]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1> -// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1> -// CHECK: return %[[VAL_17]] : !torch.vtensor<[5,5],i1> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_9]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_11]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.sub %[[VAL_3]], %[[VAL_2]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_14:.*]] = tosa.abs %[[VAL_13]] : (tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_15:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_17:.*]] = tosa.mul 
%[[VAL_10]], %[[VAL_15]], %[[VAL_16]] : (tensor<1x1xf32>, tensor<5x5xf32>, tensor<1xi8>) -> tensor<5x5xf32> +// CHECK: %[[VAL_18:.*]] = tosa.add %[[VAL_12]], %[[VAL_17]] : (tensor<1x1xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_19:.*]] = tosa.greater_equal %[[VAL_18]], %[[VAL_14]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1> +// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1> +// CHECK: return %[[VAL_20]] : !torch.vtensor<[5,5],i1> // CHECK: } func.func @torch.aten.isclose$basic(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { %float1.000000e-08 = torch.constant.float 1.000000e-08 @@ -1687,9 +1877,10 @@ func.func @torch.aten.__interpolate.size_list_scale_list.nearest(%arg0: !torch.v // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],si32> -> tensor<2x4xi32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 1, 0, 0], [1, 1, 1, 0]]> : tensor<2x4xi32>}> : () -> tensor<2x4xi32> -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x4xi32> -> !torch.vtensor<[2,4],si32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,4],si32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]], %[[VAL_4]] : (tensor<2x4xi32>, tensor<2x4xi32>, tensor<1xi8>) -> tensor<2x4xi32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<2x4xi32> -> !torch.vtensor<[2,4],si32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[2,4],si32> // CHECK: } func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.vtensor<[2,4], si32> { %int0 = torch.constant.int 1 @@ -1701,17 
+1892,19 @@ func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.v // CHECK-LABEL: func.func @torch.aten.min.dim$basic( // CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> -// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> -// CHECK-DAG: %[[VAL_3:.*]] = torch.constant.bool true -// CHECK-DAG: %[[VAL_4:.*]] = torch.constant.int 2 -// CHECK-DAG: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> -// CHECK-DAG: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32> -// CHECK-DAG: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> -// CHECK-DAG: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<3x2xi64>) -> tensor<3x2x1xi64> -// CHECK-DAG: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> -// CHECK-DAG: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> -// CHECK: return %[[VAL_10]] : tensor<3x2x1xf32> +// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> +// CHECK: %[[VAL_3:.*]] = torch.constant.bool true +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[3, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_6:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> +// CHECK: %[[VAL_8:.*]] = 
tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32> +// CHECK: %[[VAL_9:.*]] = tosa.argmax %[[VAL_8]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[3, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_10]] : (tensor<3x2xi64>, !tosa.shape<3>) -> tensor<3x2x1xi64> +// CHECK: %[[VAL_12:.*]] = torch_c.to_builtin_tensor %[[VAL_7]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> +// CHECK: return %[[VAL_12]] : tensor<3x2x1xf32> // CHECK: } func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> { %0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> @@ -1730,9 +1923,10 @@ func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf3 // CHECK: %[[VAL_2:.*]] = tosa.reduce_min %[[VAL_1]] {axis = 0 : i32} : (tensor<3x2x3xf32>) -> tensor<1x2x3xf32> // CHECK: %[[VAL_3:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xf32>) -> tensor<1x1x3xf32> // CHECK: %[[VAL_4:.*]] = tosa.reduce_min %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x1x1xf32>) -> tensor<1xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_5]] : (tensor<1x1x1xf32>, !tosa.shape<1>) -> tensor<1xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[1],f32> // CHECK: } func.func @torch.aten.min$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> { %0 = torch.aten.min 
%arg0: !torch.vtensor<[3,2,3],f32> -> !torch.vtensor<[1],f32> @@ -1747,9 +1941,10 @@ func.func @torch.aten.min$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vt // CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 0 : i32} : (tensor<3x2x3xf32>) -> tensor<1x2x3xf32> // CHECK: %[[VAL_3:.*]] = tosa.reduce_max %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xf32>) -> tensor<1x1x3xf32> // CHECK: %[[VAL_4:.*]] = tosa.reduce_max %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x1x1xf32>) -> tensor<1xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_5]] : (tensor<1x1x1xf32>, !tosa.shape<1>) -> tensor<1xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[1],f32> // CHECK: } func.func @torch.aten.max$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> { %0 = torch.aten.max %arg0: !torch.vtensor<[3,2,3],f32> -> !torch.vtensor<[1],f32> @@ -1803,20 +1998,25 @@ func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.str "trunc" // CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// 
CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_13:.*]] = tosa.greater_equal %[[VAL_6]], %[[VAL_11]] : (tensor, tensor<1x1xf32>) -> tensor -// CHECK: %[[VAL_14:.*]] = tosa.select %[[VAL_13]], %[[VAL_10]], %[[VAL_12]] : (tensor, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor -// CHECK: %[[VAL_15:.*]] = tosa.abs %[[VAL_6]] : (tensor) -> tensor -// CHECK: %[[VAL_16:.*]] = tosa.floor %[[VAL_15]] : (tensor) -> tensor -// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_14]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_18]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]], %[[VAL_6]] : (tensor, tensor, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_11]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_13]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> 
+// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_10]], %[[VAL_15]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.greater_equal %[[VAL_7]], %[[VAL_14]] : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_18:.*]] = tosa.select %[[VAL_17]], %[[VAL_12]], %[[VAL_16]] : (tensor, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_19:.*]] = tosa.abs %[[VAL_7]] : (tensor) -> tensor +// CHECK: %[[VAL_20:.*]] = tosa.floor %[[VAL_19]] : (tensor) -> tensor +// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_20]], %[[VAL_18]], %[[VAL_21]] : (tensor, tensor, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_23:.*]] = torch_c.from_builtin_tensor %[[VAL_22]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_23]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.div.Tensor_mode$float_trunc(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %str = torch.constant.str "trunc" @@ -1854,10 +2054,11 @@ func.func @torch.aten.div.Tensor_mode$int_trunc(%arg0: !torch.vtensor<[?, ?],si6 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.str "floor" // CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.floor %[[VAL_6]] : (tensor) -> tensor -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3]], 
%[[VAL_5]], %[[VAL_6]] : (tensor, tensor, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.div.Tensor_mode$float_floor(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %str = torch.constant.str "floor" @@ -1878,19 +2079,23 @@ func.func @torch.aten.div.Tensor_mode$float_floor(%arg0: !torch.vtensor<[?, ?],f // CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor) -> tensor<1x1xi32> -// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1xi32> -// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_5]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_13:.*]] = tosa.greater %[[VAL_11]], %[[VAL_12]] : (tensor<1x1xi32>, tensor) -> tensor -// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_7]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_15:.*]] = tosa.equal %[[VAL_14]], %[[VAL_5]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_16:.*]] = tosa.logical_not %[[VAL_15]] : (tensor) -> tensor -// CHECK: %[[VAL_17:.*]] = tosa.sub %[[VAL_7]], %[[VAL_10]] : (tensor, tensor<1x1xi32>) -> tensor -// CHECK: %[[VAL_18:.*]] = tosa.logical_and %[[VAL_13]], %[[VAL_16]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_19:.*]] = tosa.select %[[VAL_18]], %[[VAL_17]], %[[VAL_7]] : (tensor, tensor, tensor) -> tensor -// CHECK: %[[VAL_20:.*]] = tosa.cast %[[VAL_19]] : (tensor) -> tensor -// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor 
%[[VAL_20]] : tensor -> !torch.vtensor<[?,?],si64> -// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?],si64> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_10]] : (tensor, !tosa.shape<2>) -> tensor<1x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_12]] : (tensor, !tosa.shape<2>) -> tensor<1x1xi32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_5]], %[[VAL_6]], %[[VAL_14]] : (tensor, tensor, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_16:.*]] = tosa.greater %[[VAL_13]], %[[VAL_15]] : (tensor<1x1xi32>, tensor) -> tensor +// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_7]], %[[VAL_6]], %[[VAL_17]] : (tensor, tensor, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_19:.*]] = tosa.equal %[[VAL_18]], %[[VAL_5]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_20:.*]] = tosa.logical_not %[[VAL_19]] : (tensor) -> tensor +// CHECK: %[[VAL_21:.*]] = tosa.sub %[[VAL_7]], %[[VAL_11]] : (tensor, tensor<1x1xi32>) -> tensor +// CHECK: %[[VAL_22:.*]] = tosa.logical_and %[[VAL_16]], %[[VAL_20]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_23:.*]] = tosa.select %[[VAL_22]], %[[VAL_21]], %[[VAL_7]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_24:.*]] = tosa.cast %[[VAL_23]] : (tensor) -> tensor +// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor -> !torch.vtensor<[?,?],si64> +// CHECK: return %[[VAL_25]] : !torch.vtensor<[?,?],si64> // CHECK: } func.func @torch.aten.div.Tensor_mode$int_floor(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> { %str = torch.constant.str "floor" @@ 
-1907,9 +2112,10 @@ func.func @torch.aten.div.Tensor_mode$int_floor(%arg0: !torch.vtensor<[?, ?],si6 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.str "" // CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]], %[[VAL_6]] : (tensor, tensor, tensor<1xi8>) -> tensor +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.div.Tensor_mode$float_basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %str = torch.constant.str "" @@ -1962,12 +2168,14 @@ func.func @torch.aten.ge.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! 
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> // CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_3]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,4],f32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (tensor<2x4xf32>, tensor<2x4xf32>, tensor<1xi8>) -> tensor<2x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.floor %[[VAL_6]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_2]], %[[VAL_7]], %[[VAL_8]] : (tensor<2x4xf32>, tensor<2x4xf32>, tensor<1xi8>) -> tensor<2x4xf32> +// CHECK: %[[VAL_10:.*]] = tosa.sub %[[VAL_3]], %[[VAL_9]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[2,4],f32> // CHECK: } func.func @torch.aten.remainder.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> { %0 = torch.aten.remainder.Tensor %arg0, 
%arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32> @@ -1982,22 +2190,28 @@ func.func @torch.aten.remainder.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> // CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_12:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_10]] : (tensor<2x4xf32>, tensor<1x1xf32>) -> tensor<2x4xi1> -// CHECK: %[[VAL_13:.*]] = tosa.select %[[VAL_12]], %[[VAL_9]], %[[VAL_11]] : (tensor<2x4xi1>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_14:.*]] = tosa.abs %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_15:.*]] = tosa.floor %[[VAL_14]] : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_13]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_2]], %[[VAL_16]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_18:.*]] = tosa.sub 
%[[VAL_3]], %[[VAL_17]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> -// CHECK: return %[[VAL_19]] : !torch.vtensor<[2,4],f32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (tensor<2x4xf32>, tensor<2x4xf32>, tensor<1xi8>) -> tensor<2x4xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_10]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_12]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_14]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.greater_equal %[[VAL_6]], %[[VAL_13]] : (tensor<2x4xf32>, tensor<1x1xf32>) -> tensor<2x4xi1> +// CHECK: %[[VAL_17:.*]] = tosa.select %[[VAL_16]], %[[VAL_11]], %[[VAL_15]] : (tensor<2x4xi1>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_18:.*]] = tosa.abs %[[VAL_6]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_19:.*]] = tosa.floor %[[VAL_18]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: 
%[[VAL_21:.*]] = tosa.mul %[[VAL_19]], %[[VAL_17]], %[[VAL_20]] : (tensor<2x4xf32>, tensor<2x4xf32>, tensor<1xi8>) -> tensor<2x4xf32> +// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_2]], %[[VAL_21]], %[[VAL_22]] : (tensor<2x4xf32>, tensor<2x4xf32>, tensor<1xi8>) -> tensor<2x4xf32> +// CHECK: %[[VAL_24:.*]] = tosa.sub %[[VAL_3]], %[[VAL_23]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_25]] : !torch.vtensor<[2,4],f32> // CHECK: } func.func @torch.aten.fmod.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> { %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32> @@ -2053,10 +2267,11 @@ func.func @torch.aten.sin(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.float 2.000000e+00 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.pow %[[VAL_4]], %[[VAL_1]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.pow %[[VAL_5]], %[[VAL_1]] : (tensor<1x1xf32>, 
tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.pow.Scalar(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { %float2.000000e00 = torch.constant.float 2.000000e+00 @@ -2101,11 +2316,12 @@ func.func @torch.aten.erf$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xi64> -// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x1xi64>) -> tensor<1x1xi32> -// CHECK: %[[VAL_6:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_5]] : (tensor, tensor<1x1xi32>) -> tensor -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],si32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],si32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor, !tosa.shape<2>) -> tensor<1x1xi64> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<1x1xi64>) -> tensor<1x1xi32> +// CHECK: %[[VAL_7:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_6]] : (tensor, tensor<1x1xi32>) -> tensor +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],si32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],si32> // CHECK: } func.func @torch.aten.bitwise_and.Scalar$basic(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { %int2 = torch.constant.int 2 @@ -2136,11 +2352,12 @@ func.func @torch.aten.le.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! 
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xi64> -// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x1xi64>) -> tensor<1x1xf32> -// CHECK: %[[VAL_6:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_1]] : (tensor<1x1xf32>, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],i1> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],i1> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor, !tosa.shape<2>) -> tensor<1x1xi64> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<1x1xi64>) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_6]], %[[VAL_1]] : (tensor<1x1xf32>, tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],i1> // CHECK: } func.func @torch.aten.le.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { %int2 = torch.constant.int 2 @@ -2207,12 +2424,16 @@ func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?, // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_6:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_5]] : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<5x6x4x3xi32> // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32> -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]] {shift = 0 : i8} : 
(tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>) -> tensor<5x6x4x3xi32> -// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array, start = array} : (tensor<5x6x4x3xi32>) -> tensor<5x6x2x3xi32> -// CHECK: %[[VAL_10:.*]] = tosa.reduce_sum %[[VAL_9]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32> -// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<5x6x2x1xi32>) -> tensor<5x6x2xi32> -// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32> -// CHECK: return %[[VAL_12]] : !torch.vtensor<[5,6,2],si32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : (tensor<5x6x4x3xi32>, tensor<1x1x4x3xi32>, tensor<1xi8>) -> tensor<5x6x4x3xi32> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[0, 0, 2, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<[5, 6, 2, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_12:.*]] = tosa.slice %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] : (tensor<5x6x4x3xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<5x6x2x3xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[5, 6, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_13]], %[[VAL_14]] : (tensor<5x6x2x1xi32>, !tosa.shape<3>) -> tensor<5x6x2xi32> +// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32> +// CHECK: return %[[VAL_16]] : !torch.vtensor<[5,6,2],si32> // CHECK: } func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> !torch.vtensor<[5,6,2], si32> { %dim1 = torch.constant.int 1 @@ -2231,24 +2452,32 @@ func.func 
@torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32> // CHECK: %[[VAL_4:.*]] = torch.constant.int 2 // CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32> -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x1x2xi32> -// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<[4, 5, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK: %[[VAL_8:.*]] = tosa.tile %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x2xi32>, !tosa.shape<3>) -> tensor<4x5x2xi32> -// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> -// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_11]], %[[VAL_9]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32> -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} 
: (tensor<4x5x6xf32>) -> tensor<1x120x1xf32> -// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32> -// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> -// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<1x3xi32>) -> tensor<40x3xi32> -// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32> -// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<40x1xi32>) -> tensor<1x40xi32> -// CHECK: %[[VAL_20:.*]] = tosa.gather %[[VAL_13]], %[[VAL_19]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32> -// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32> -// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32> -// CHECK: return %[[VAL_22]] : !torch.vtensor<[4,5,2],f32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor<2xi32>, !tosa.shape<3>) -> tensor<1x1x2xi32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[4, 5, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_9:.*]] = tosa.tile %[[VAL_7]], %[[VAL_8]] : (tensor<1x1x2xi32>, !tosa.shape<3>) -> tensor<4x5x2xi32> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[4, 5, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_10]] : (tensor<4x5x2xi32>, !tosa.shape<4>) -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], 
{{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_14:.*]] = tosa.concat %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[1, 120, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_15]] : (tensor<4x5x6xf32>, !tosa.shape<3>) -> tensor<1x120x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[40, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_14]], %[[VAL_17]] : (tensor<4x5x2x3xi32>, !tosa.shape<2>) -> tensor<40x3xi32> +// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_20:.*]] = tosa.const_shape {value = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_19]], %[[VAL_20]] : (tensor<3xi32>, !tosa.shape<2>) -> tensor<1x3xi32> +// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_23:.*]] = tosa.mul 
%[[VAL_18]], %[[VAL_21]], %[[VAL_22]] : (tensor<40x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<40x3xi32> +// CHECK: %[[VAL_24:.*]] = tosa.reduce_sum %[[VAL_23]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32> +// CHECK: %[[VAL_25:.*]] = tosa.const_shape {value = dense<[1, 40]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_24]], %[[VAL_25]] : (tensor<40x1xi32>, !tosa.shape<2>) -> tensor<1x40xi32> +// CHECK: %[[VAL_27:.*]] = tosa.gather %[[VAL_16]], %[[VAL_26]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32> +// CHECK: %[[VAL_28:.*]] = tosa.const_shape {value = dense<[4, 5, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_29:.*]] = tosa.reshape %[[VAL_27]], %[[VAL_28]] : (tensor<1x40x1xf32>, !tosa.shape<3>) -> tensor<4x5x2xf32> +// CHECK: %[[VAL_30:.*]] = torch_c.from_builtin_tensor %[[VAL_29]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32> +// CHECK: return %[[VAL_30]] : !torch.vtensor<[4,5,2],f32> // CHECK: } func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> { %int2 = torch.constant.int 2 @@ -2262,9 +2491,8 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !t // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> { // CHECK: %[[VAL_1:.*]] = torch.constant.int 0 // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x12x128x128xf32>}> : () -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> -// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<1x12x128x128xf32> -> 
!torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> { %int0 = torch.constant.int 0 @@ -2278,12 +2506,13 @@ func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32> -// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1x1x1xi32> -// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 12, 128, 128]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_5:.*]] = tosa.tile %[[VAL_3]], %[[VAL_4]] : (tensor<1x1x1x1xi32>, !tosa.shape<4>) -> tensor<1x12x128x128xi32> -// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: %[[VAL_3:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_3]] : (tensor<1xi32>, !tosa.shape<4>) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[1, 12, 128, 128]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_6:.*]] = tosa.tile %[[VAL_4]], %[[VAL_5]] : (tensor<1x1x1x1xi32>, !tosa.shape<4>) -> tensor<1x12x128x128xi32> +// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_6]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x128x128xf32> -> 
!torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> { %0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[1,12,128,128],f32> @@ -2318,22 +2547,26 @@ func.func @torch.aten.flip(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> -// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_8:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_6]], %[[VAL_4]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_10:.*]] = tosa.floor %[[VAL_9]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_5]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_12:.*]] = tosa.equal %[[VAL_6]], %[[VAL_11]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> -// CHECK: %[[VAL_13:.*]] = tosa.equal %[[VAL_7]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xi1> -// CHECK: %[[VAL_14:.*]] = tosa.greater %[[VAL_4]], 
%[[VAL_7]] : (tensor<1x1x1xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> -// CHECK: %[[VAL_15:.*]] = tosa.logical_and %[[VAL_13]], %[[VAL_12]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> -// CHECK: %[[VAL_16:.*]] = tosa.logical_or %[[VAL_14]], %[[VAL_15]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> -// CHECK: %[[VAL_17:.*]] = tosa.select %[[VAL_16]], %[[VAL_6]], %[[VAL_8]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> -// CHECK: return %[[VAL_18]] : !torch.vtensor<[3,4,5],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_4]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_6]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_10:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_8]], %[[VAL_5]], %[[VAL_11]] : (tensor<3x4x5xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_13:.*]] = tosa.floor %[[VAL_12]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_7]], %[[VAL_14]] : (tensor<3x4x5xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<3x4x5xf32> +// 
CHECK: %[[VAL_16:.*]] = tosa.equal %[[VAL_8]], %[[VAL_15]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_17:.*]] = tosa.equal %[[VAL_9]], %[[VAL_5]] : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_18:.*]] = tosa.greater %[[VAL_5]], %[[VAL_9]] : (tensor<1x1x1xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_19:.*]] = tosa.logical_and %[[VAL_17]], %[[VAL_16]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_20:.*]] = tosa.logical_or %[[VAL_18]], %[[VAL_19]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_21:.*]] = tosa.select %[[VAL_20]], %[[VAL_8]], %[[VAL_10]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[3,4,5],f32> // CHECK: } func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { %0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -2389,9 +2622,8 @@ func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torc // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi32>}> : () -> tensor<3x4xi32> // CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<3x4xi32>) -> tensor<3x4xi64> // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi64>}> : () -> tensor<3x4xi64> -// CHECK: %[[VAL_10:.*]] = tosa.cast %[[VAL_9]] : (tensor<3x4xi64>) -> tensor<3x4xi64> -// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64> -// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],si64> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,4],si64> // 
CHECK: } func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> { %int0 = torch.constant.int 0 @@ -2417,22 +2649,30 @@ func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> // CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,8,6],f32> -> tensor<10x8x6xf32> // CHECK: %[[VAL_6:.*]] = torch.constant.int 1 // CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_4]] : (tensor<2x4x3xi64>) -> tensor<2x4x3xi32> -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<2x4x3xi32>) -> tensor<2x4x3x1xi32> -// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]]], {{\[\[}}[1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]], {{\[\[}}[0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32> -// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_8]], %[[VAL_10]] {axis = 3 : i32} : (tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>) -> tensor<2x4x3x3xi32> -// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3x4x3xf32>) -> tensor<1x36x1xf32> -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<10x8x6xf32>) -> tensor<1x480x1xf32> -// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<2x4x3x3xi32>) -> tensor<24x3xi32> -// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[48, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> -// CHECK: %[[VAL_17:.*]] = tosa.mul 
%[[VAL_14]], %[[VAL_16]] {shift = 0 : i8} : (tensor<24x3xi32>, tensor<1x3xi32>) -> tensor<24x3xi32> -// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<24x3xi32>) -> tensor<24x1xi32> -// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> -// CHECK: %[[VAL_20:.*]] = tosa.scatter %[[VAL_13]], %[[VAL_19]], %[[VAL_12]] : (tensor<1x480x1xf32>, tensor<1x24xi32>, tensor<1x36x1xf32>) -> tensor<1x480x1xf32> -// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x480x1xf32>) -> tensor<10x8x6xf32> -// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<10x8x6xf32> -> !torch.vtensor<[10,8,6],f32> -// CHECK: return %[[VAL_22]] : !torch.vtensor<[10,8,6],f32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[2, 4, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_8]] : (tensor<2x4x3xi32>, !tosa.shape<4>) -> tensor<2x4x3x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]]], {{\[\[}}[1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]], {{\[\[}}[0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_9]], %[[VAL_11]] {axis = 3 : i32} : (tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>) -> tensor<2x4x3x3xi32> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[1, 36, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_3]], 
%[[VAL_13]] : (tensor<3x4x3xf32>, !tosa.shape<3>) -> tensor<1x36x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[1, 480, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_15]] : (tensor<10x8x6xf32>, !tosa.shape<3>) -> tensor<1x480x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[24, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_12]], %[[VAL_17]] : (tensor<2x4x3x3xi32>, !tosa.shape<2>) -> tensor<24x3xi32> +// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<[48, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_20:.*]] = tosa.const_shape {value = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_19]], %[[VAL_20]] : (tensor<3xi32>, !tosa.shape<2>) -> tensor<1x3xi32> +// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_18]], %[[VAL_21]], %[[VAL_22]] : (tensor<24x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<24x3xi32> +// CHECK: %[[VAL_24:.*]] = tosa.reduce_sum %[[VAL_23]] {axis = 1 : i32} : (tensor<24x3xi32>) -> tensor<24x1xi32> +// CHECK: %[[VAL_25:.*]] = tosa.const_shape {value = dense<[1, 24]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_24]], %[[VAL_25]] : (tensor<24x1xi32>, !tosa.shape<2>) -> tensor<1x24xi32> +// CHECK: %[[VAL_27:.*]] = tosa.scatter %[[VAL_16]], %[[VAL_26]], %[[VAL_14]] : (tensor<1x480x1xf32>, tensor<1x24xi32>, tensor<1x36x1xf32>) -> tensor<1x480x1xf32> +// CHECK: %[[VAL_28:.*]] = tosa.const_shape {value = dense<[10, 8, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_29:.*]] = tosa.reshape %[[VAL_27]], %[[VAL_28]] : (tensor<1x480x1xf32>, !tosa.shape<3>) -> tensor<10x8x6xf32> +// CHECK: %[[VAL_30:.*]] = torch_c.from_builtin_tensor %[[VAL_29]] : tensor<10x8x6xf32> -> 
!torch.vtensor<[10,8,6],f32> +// CHECK: return %[[VAL_30]] : !torch.vtensor<[10,8,6],f32> // CHECK: } func.func @torch.aten.scatter.src$basic(%arg0: !torch.vtensor<[10,8,6],f32>, %arg1: !torch.vtensor<[2,4,3],si64>, %arg2: !torch.vtensor<[3,4,3],f32>) -> !torch.vtensor<[10,8,6],f32> { %int1 = torch.constant.int 1 @@ -2450,21 +2690,29 @@ func.func @torch.aten.scatter.src$basic(%arg0: !torch.vtensor<[10,8,6],f32>, %ar // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = torch.constant.int 0 // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<6x1xi32>}> : () -> tensor<6x1xi32> -// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<6x1xi32>) -> tensor<6x1x1xi32> -// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]], {{\[\[}}4]], {{\[\[}}5]]]> : tensor<6x1x1xi32>}> : () -> tensor<6x1x1xi32> -// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_8]], %[[VAL_7]] {axis = 2 : i32} : (tensor<6x1x1xi32>, tensor<6x1x1xi32>) -> tensor<6x1x2xi32> -// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> -// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<6x8xf32>) -> tensor<1x48x1xf32> -// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<6x1x2xi32>) -> tensor<6x2xi32> -// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[8, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> -// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x2xi32> -// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] {shift = 0 : i8} : (tensor<6x2xi32>, tensor<1x2xi32>) -> tensor<6x2xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<6x2xi32>) -> tensor<6x1xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<6x1xi32>) -> tensor<1x6xi32> -// CHECK: %[[VAL_18:.*]] 
= tosa.scatter %[[VAL_11]], %[[VAL_17]], %[[VAL_10]] : (tensor<1x48x1xf32>, tensor<1x6xi32>, tensor<1x6x1xf32>) -> tensor<1x48x1xf32> -// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x48x1xf32>) -> tensor<6x8xf32> -// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<6x8xf32> -> !torch.vtensor<[6,8],f32> -// CHECK: return %[[VAL_20]] : !torch.vtensor<[6,8],f32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<[6, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<6x1xi32>, !tosa.shape<3>) -> tensor<6x1x1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]], {{\[\[}}4]], {{\[\[}}5]]]> : tensor<6x1x1xi32>}> : () -> tensor<6x1x1xi32> +// CHECK: %[[VAL_10:.*]] = tosa.concat %[[VAL_9]], %[[VAL_8]] {axis = 2 : i32} : (tensor<6x1x1xi32>, tensor<6x1x1xi32>) -> tensor<6x1x2xi32> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<[1, 6, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_11]] : (tensor<6x1xf32>, !tosa.shape<3>) -> tensor<1x6x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[1, 48, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_13]] : (tensor<6x8xf32>, !tosa.shape<3>) -> tensor<1x48x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[6, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_10]], %[[VAL_15]] : (tensor<6x1x2xi32>, !tosa.shape<2>) -> tensor<6x2xi32> +// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<[8, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[VAL_18:.*]] = tosa.const_shape {value = dense<[1, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_17]], %[[VAL_18]] : (tensor<2xi32>, 
!tosa.shape<2>) -> tensor<1x2xi32> +// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_21:.*]] = tosa.mul %[[VAL_16]], %[[VAL_19]], %[[VAL_20]] : (tensor<6x2xi32>, tensor<1x2xi32>, tensor<1xi8>) -> tensor<6x2xi32> +// CHECK: %[[VAL_22:.*]] = tosa.reduce_sum %[[VAL_21]] {axis = 1 : i32} : (tensor<6x2xi32>) -> tensor<6x1xi32> +// CHECK: %[[VAL_23:.*]] = tosa.const_shape {value = dense<[1, 6]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_22]], %[[VAL_23]] : (tensor<6x1xi32>, !tosa.shape<2>) -> tensor<1x6xi32> +// CHECK: %[[VAL_25:.*]] = tosa.scatter %[[VAL_14]], %[[VAL_24]], %[[VAL_12]] : (tensor<1x48x1xf32>, tensor<1x6xi32>, tensor<1x6x1xf32>) -> tensor<1x48x1xf32> +// CHECK: %[[VAL_26:.*]] = tosa.const_shape {value = dense<[6, 8]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_27:.*]] = tosa.reshape %[[VAL_25]], %[[VAL_26]] : (tensor<1x48x1xf32>, !tosa.shape<2>) -> tensor<6x8xf32> +// CHECK: %[[VAL_28:.*]] = torch_c.from_builtin_tensor %[[VAL_27]] : tensor<6x8xf32> -> !torch.vtensor<[6,8],f32> +// CHECK: return %[[VAL_28]] : !torch.vtensor<[6,8],f32> // CHECK: } func.func @torch.aten.slice_scatter$basic(%arg0: !torch.vtensor<[6,8],f32>, %arg1: !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[6,8],f32> { %int1 = torch.constant.int 1 @@ -2482,27 +2730,36 @@ func.func @torch.aten.slice_scatter$basic(%arg0: !torch.vtensor<[6,8],f32>, %arg // CHECK: %[[VAL_3:.*]] = torch.constant.int -2 // CHECK: %[[VAL_4:.*]] = torch.constant.int -1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]], {{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]]> : tensor<2x3x4x1xi32>}> : () -> tensor<2x3x4x1xi32> -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<2x3x4xf32>) -> tensor<2x3x4x1xf32> -// CHECK: 
%[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3x4x4xf32>}> : () -> tensor<2x3x4x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2x3x4x1xi32>) -> tensor<2x3x4x1x1xi32> -// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]]], {{\[\[}}{{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[2]], {{\[\[}}2]], {{\[\[}}2]], {{\[\[}}2]]]], {{\[\[}}{{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[2]], {{\[\[}}2]], {{\[\[}}2]], {{\[\[}}2]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32> -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]]], {{\[\[}}{{\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32> -// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_11]], %[[VAL_8]] {axis = 4 : i32} : (tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>) -> tensor<2x3x4x1x4xi32> -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<2x3x4x1xf32>) -> tensor<1x24x1xf32> -// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_7]] 
{new_shape = array} : (tensor<2x3x4x4xf32>) -> tensor<1x96x1xf32> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<2x3x4x1x4xi32>) -> tensor<24x4xi32> -// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[48, 16, 4, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<4xi32>) -> tensor<1x4xi32> -// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]] {shift = 0 : i8} : (tensor<24x4xi32>, tensor<1x4xi32>) -> tensor<24x4xi32> -// CHECK: %[[VAL_19:.*]] = tosa.reduce_sum %[[VAL_18]] {axis = 1 : i32} : (tensor<24x4xi32>) -> tensor<24x1xi32> -// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> -// CHECK: %[[VAL_21:.*]] = tosa.scatter %[[VAL_14]], %[[VAL_20]], %[[VAL_13]] : (tensor<1x96x1xf32>, tensor<1x24xi32>, tensor<1x24x1xf32>) -> tensor<1x96x1xf32> -// CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array} : (tensor<1x96x1xf32>) -> tensor<2x3x4x4xf32> -// CHECK: %[[VAL_23:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_24:.*]] = tosa.transpose %[[VAL_22]], %[[VAL_23]] : (tensor<2x3x4x4xf32>, tensor<4xi32>) -> tensor<2x3x4x4xf32> -// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<2x3x4x4xf32> -> !torch.vtensor<[2,3,4,4],f32> -// CHECK: return %[[VAL_25]] : !torch.vtensor<[2,3,4,4],f32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[2, 3, 4, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_6]] : (tensor<2x3x4xf32>, !tosa.shape<4>) -> tensor<2x3x4x1xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3x4x4xf32>}> : () -> tensor<2x3x4x4xf32> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[2, 3, 4, 1, 1]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_10:.*]] = 
tosa.reshape %[[VAL_5]], %[[VAL_9]] : (tensor<2x3x4x1xi32>, !tosa.shape<5>) -> tensor<2x3x4x1x1xi32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]]], {{\[\[}}{{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[2]], {{\[\[}}2]], {{\[\[}}2]], {{\[\[}}2]]]], {{\[\[}}{{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[2]], {{\[\[}}2]], {{\[\[}}2]], {{\[\[}}2]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]]], {{\[\[}}{{\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32> +// CHECK: %[[VAL_14:.*]] = tosa.concat %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_10]] {axis = 4 : i32} : (tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>) -> tensor<2x3x4x1x4xi32> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[1, 24, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_15]] : (tensor<2x3x4x1xf32>, !tosa.shape<3>) -> tensor<1x24x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = 
dense<[1, 96, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_17]] : (tensor<2x3x4x4xf32>, !tosa.shape<3>) -> tensor<1x96x1xf32> +// CHECK: %[[VAL_19:.*]] = tosa.const_shape {value = dense<[24, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_14]], %[[VAL_19]] : (tensor<2x3x4x1x4xi32>, !tosa.shape<2>) -> tensor<24x4xi32> +// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{value = dense<[48, 16, 4, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_22:.*]] = tosa.const_shape {value = dense<[1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_23:.*]] = tosa.reshape %[[VAL_21]], %[[VAL_22]] : (tensor<4xi32>, !tosa.shape<2>) -> tensor<1x4xi32> +// CHECK: %[[VAL_24:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_25:.*]] = tosa.mul %[[VAL_20]], %[[VAL_23]], %[[VAL_24]] : (tensor<24x4xi32>, tensor<1x4xi32>, tensor<1xi8>) -> tensor<24x4xi32> +// CHECK: %[[VAL_26:.*]] = tosa.reduce_sum %[[VAL_25]] {axis = 1 : i32} : (tensor<24x4xi32>) -> tensor<24x1xi32> +// CHECK: %[[VAL_27:.*]] = tosa.const_shape {value = dense<[1, 24]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_28:.*]] = tosa.reshape %[[VAL_26]], %[[VAL_27]] : (tensor<24x1xi32>, !tosa.shape<2>) -> tensor<1x24xi32> +// CHECK: %[[VAL_29:.*]] = tosa.scatter %[[VAL_18]], %[[VAL_28]], %[[VAL_16]] : (tensor<1x96x1xf32>, tensor<1x24xi32>, tensor<1x24x1xf32>) -> tensor<1x96x1xf32> +// CHECK: %[[VAL_30:.*]] = tosa.const_shape {value = dense<[2, 3, 4, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_31:.*]] = tosa.reshape %[[VAL_29]], %[[VAL_30]] : (tensor<1x96x1xf32>, !tosa.shape<4>) -> tensor<2x3x4x4xf32> +// CHECK: %[[VAL_32:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_33:.*]] = tosa.transpose %[[VAL_31]], %[[VAL_32]] : (tensor<2x3x4x4xf32>, tensor<4xi32>) -> 
tensor<2x3x4x4xf32> +// CHECK: %[[VAL_34:.*]] = torch_c.from_builtin_tensor %[[VAL_33]] : tensor<2x3x4x4xf32> -> !torch.vtensor<[2,3,4,4],f32> +// CHECK: return %[[VAL_34]] : !torch.vtensor<[2,3,4,4],f32> // CHECK: } func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> { %int0 = torch.constant.int 0 @@ -2526,18 +2783,25 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t // CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_5]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_9:.*]] = tosa.greater %[[VAL_6]], %[[VAL_5]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_10:.*]] = tosa.select %[[VAL_9]], %[[VAL_8]], %[[VAL_5]] : (tensor, tensor, tensor) -> tensor -// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor) -> tensor<1xi32> -// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64> -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> -// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1x1xi32>) -> tensor<1x1xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32> -// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<1x1xi32>) -> tensor<1x1xi32> -// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_12]], %[[VAL_18]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64> -// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x1x8xi64>) -> tensor<4x2xi64> -// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<4x2xi64> -> 
!torch.vtensor<[4,2],si64> -// CHECK: return %[[VAL_21]] : !torch.vtensor<[4,2],si64> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]], %[[VAL_11]] : (tensor, !tosa.shape<1>) -> tensor<1xi32> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[1, 2, 8]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_13]] : (tensor<2x4x2xi64>, !tosa.shape<3>) -> tensor<1x2x8xi64> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_12]], %[[VAL_15]] : (tensor<1xi32>, !tosa.shape<2>) -> tensor<1x1xi32> +// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_18:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_17]], %[[VAL_18]] : (tensor<1xi32>, !tosa.shape<2>) -> tensor<1x1xi32> +// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_21:.*]] = tosa.mul %[[VAL_16]], %[[VAL_19]], %[[VAL_20]] : (tensor<1x1xi32>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<1x1xi32> +// CHECK: %[[VAL_22:.*]] = tosa.reduce_sum %[[VAL_21]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_23:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_22]], %[[VAL_23]] : (tensor<1x1xi32>, !tosa.shape<2>) -> tensor<1x1xi32> +// CHECK: %[[VAL_25:.*]] = tosa.gather %[[VAL_14]], %[[VAL_24]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64> +// CHECK: %[[VAL_26:.*]] = tosa.const_shape {value = dense<[4, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_27:.*]] = tosa.reshape %[[VAL_25]], %[[VAL_26]] : (tensor<1x1x8xi64>, 
!tosa.shape<2>) -> tensor<4x2xi64> +// CHECK: %[[VAL_28:.*]] = torch_c.from_builtin_tensor %[[VAL_27]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64> +// CHECK: return %[[VAL_28]] : !torch.vtensor<[4,2],si64> // CHECK: } func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list @@ -2556,10 +2820,11 @@ func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si6 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor<4xi64>}> : () -> tensor<4xi64> // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_2]] : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi1> -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor) -> tensor<1xi64> -// CHECK: %[[VAL_9:.*]] = tosa.select %[[VAL_7]], %[[VAL_8]], %[[VAL_3]] : (tensor<4xi1>, tensor<1xi64>, tensor<4xi64>) -> tensor<4xi64> -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<4xi64> -> !torch.vtensor<[4],si64> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[4],si64> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_8]] : (tensor, !tosa.shape<1>) -> tensor<1xi64> +// CHECK: %[[VAL_10:.*]] = tosa.select %[[VAL_7]], %[[VAL_9]], %[[VAL_3]] : (tensor<4xi1>, tensor<1xi64>, tensor<4xi64>) -> tensor<4xi64> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<4xi64> -> !torch.vtensor<[4],si64> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[4],si64> // CHECK: } func.func @torch.aten.threshold_backward$basic(%arg0: !torch.vtensor<[4],si64>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> { %int1 = torch.constant.int 1 @@ -2627,22 +2892,31 @@ func.func 
@torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch // CHECK: %[[VAL_5:.*]] = torch.constant.int 3 // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<5x5xf32>) -> tensor<25xf32> -// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 2, 3, 4, 4, 5, 6]> : tensor<9xi32>}> : () -> tensor<9xi32> -// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<9xi32>) -> tensor<9x1xi32> -// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_10]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32> -// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<25xf32>) -> tensor<1x25x1xf32> -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<9x1xi32>) -> tensor<9x1xi32> -// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1x1xi32>) -> tensor<9x1xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32> -// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<9x1xi32>) -> tensor<1x9xi32> -// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_12]], %[[VAL_18]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32> -// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x9x1xf32>) -> tensor<9xf32> -// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<9xf32>) -> tensor<3x3xf32> -// CHECK: %[[VAL_22:.*]] = 
torch_c.from_builtin_tensor %[[VAL_21]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32> -// CHECK: return %[[VAL_22]] : !torch.vtensor<[3,3],f32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<25> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_8]] : (tensor<5x5xf32>, !tosa.shape<1>) -> tensor<25xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 2, 3, 4, 4, 5, 6]> : tensor<9xi32>}> : () -> tensor<9xi32> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<[9, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]], %[[VAL_11]] : (tensor<9xi32>, !tosa.shape<2>) -> tensor<9x1xi32> +// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_12]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[1, 25, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_14]] : (tensor<25xf32>, !tosa.shape<3>) -> tensor<1x25x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.const_shape {value = dense<[9, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_13]], %[[VAL_16]] : (tensor<9x1xi32>, !tosa.shape<2>) -> tensor<9x1xi32> +// CHECK: %[[VAL_18:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_19:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1xi32>, !tosa.shape<2>) -> tensor<1x1xi32> +// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_17]], %[[VAL_20]], %[[VAL_21]] : (tensor<9x1xi32>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<9x1xi32> +// CHECK: %[[VAL_23:.*]] = tosa.reduce_sum %[[VAL_22]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32> +// CHECK: 
%[[VAL_24:.*]] = tosa.const_shape {value = dense<[1, 9]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_23]], %[[VAL_24]] : (tensor<9x1xi32>, !tosa.shape<2>) -> tensor<1x9xi32> +// CHECK: %[[VAL_26:.*]] = tosa.gather %[[VAL_15]], %[[VAL_25]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32> +// CHECK: %[[VAL_27:.*]] = tosa.const_shape {value = dense<9> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_28:.*]] = tosa.reshape %[[VAL_26]], %[[VAL_27]] : (tensor<1x9x1xf32>, !tosa.shape<1>) -> tensor<9xf32> +// CHECK: %[[VAL_29:.*]] = tosa.const_shape {value = dense<3> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_30:.*]] = tosa.reshape %[[VAL_28]], %[[VAL_29]] : (tensor<9xf32>, !tosa.shape<2>) -> tensor<3x3xf32> +// CHECK: %[[VAL_31:.*]] = torch_c.from_builtin_tensor %[[VAL_30]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32> +// CHECK: return %[[VAL_31]] : !torch.vtensor<[3,3],f32> // CHECK: } func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> { %none = torch.constant.none @@ -2668,16 +2942,18 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !tor // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list -// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x64x112xf32>) -> tensor<1x64x112x1xf32> -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_10]], %[[VAL_11]] : (tensor<1x64x112x1xf32>, tensor<4xi32>) -> tensor<1x112x1x64xf32> -// CHECK: %[[VAL_13:.*]] = tosa.max_pool2d %[[VAL_12]] {kernel = array, pad = array, stride = array} : (tensor<1x112x1x64xf32>) -> 
tensor<1x56x1x64xf32> -// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x56x1x64xf32>, tensor<4xi32>) -> tensor<1x64x56x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<1x64x56x1xf32>) -> tensor<1x64x56xf32> -// CHECK: %[[VAL_17:.*]] = tensor.cast %[[VAL_16]] : tensor<1x64x56xf32> to tensor<1x64x56xf32> -// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x64x56xf32> -> !torch.vtensor<[1,64,56],f32> -// CHECK: return %[[VAL_18]] : !torch.vtensor<[1,64,56],f32> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 64, 112, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_10]] : (tensor<1x64x112xf32>, !tosa.shape<4>) -> tensor<1x64x112x1xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_11]], %[[VAL_12]] : (tensor<1x64x112x1xf32>, tensor<4xi32>) -> tensor<1x112x1x64xf32> +// CHECK: %[[VAL_14:.*]] = tosa.max_pool2d %[[VAL_13]] {kernel = array, pad = array, stride = array} : (tensor<1x112x1x64xf32>) -> tensor<1x56x1x64xf32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_14]], %[[VAL_15]] : (tensor<1x56x1x64xf32>, tensor<4xi32>) -> tensor<1x64x56x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[1, 64, 56]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_16]], %[[VAL_17]] : (tensor<1x64x56x1xf32>, !tosa.shape<3>) -> tensor<1x64x56xf32> +// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x64x56xf32> to tensor<1x64x56xf32> +// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : 
tensor<1x64x56xf32> -> !torch.vtensor<[1,64,56],f32> +// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,64,56],f32> // CHECK: } func.func @torch.aten.max_pool1d$basic(%arg0: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> { %false = torch.constant.bool false @@ -2703,16 +2979,18 @@ func.func @torch.aten.max_pool1d$basic(%arg0: !torch.vtensor<[1,64,112],f32>) -> // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x512x10xf32>) -> tensor<1x512x10x1xf32> -// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_10:.*]] = tosa.transpose %[[VAL_8]], %[[VAL_9]] : (tensor<1x512x10x1xf32>, tensor<4xi32>) -> tensor<1x10x1x512xf32> -// CHECK: %[[VAL_11:.*]] = tosa.avg_pool2d %[[VAL_10]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x10x1x512xf32>) -> tensor<1x10x1x512xf32> -// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_11]], %[[VAL_12]] : (tensor<1x10x1x512xf32>, tensor<4xi32>) -> tensor<1x512x10x1xf32> -// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<1x512x10x1xf32>) -> tensor<1x512x10xf32> -// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x10xf32> to tensor<1x512x10xf32> -// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32> -// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,10],f32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_9:.*]] = 
tosa.reshape %[[VAL_1]], %[[VAL_8]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_9]], %[[VAL_10]] : (tensor<1x512x10x1xf32>, tensor<4xi32>) -> tensor<1x10x1x512xf32> +// CHECK: %[[VAL_12:.*]] = tosa.avg_pool2d %[[VAL_11]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x10x1x512xf32>) -> tensor<1x10x1x512xf32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_12]], %[[VAL_13]] : (tensor<1x10x1x512xf32>, tensor<4xi32>) -> tensor<1x512x10x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_14]], %[[VAL_15]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32> +// CHECK: %[[VAL_17:.*]] = tensor.cast %[[VAL_16]] : tensor<1x512x10xf32> to tensor<1x512x10xf32> +// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32> +// CHECK: return %[[VAL_18]] : !torch.vtensor<[1,512,10],f32> // CHECK: } func.func @torch.aten.avg_pool1d$basic(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { %int1 = torch.constant.int 1 @@ -2736,23 +3014,29 @@ func.func @torch.aten.avg_pool1d$basic(%arg0: !torch.vtensor<[1,512,10],f32>) -> // CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32> // CHECK: %[[VAL_6:.*]] = torch.constant.none // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<3.40282347E+38> : tensor}> : () -> tensor -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> -// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = 
array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_10:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_8]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_11:.*]] = tosa.minimum %[[VAL_10]], %[[VAL_9]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> -// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor}> : () -> tensor -// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_14]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_17:.*]] = tosa.minimum %[[VAL_16]], %[[VAL_15]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> -// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> -// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> -// CHECK: %[[VAL_21:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_19]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_22:.*]] = tosa.minimum %[[VAL_21]], %[[VAL_20]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_23:.*]] = torch_c.from_builtin_tensor %[[VAL_22]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> -// CHECK: return %[[VAL_12]], %[[VAL_18]], %[[VAL_23]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_8]] : (tensor<1xf32>, !tosa.shape<2>) -> tensor<1x1xf32> 
+// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_10]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_9]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_13:.*]] = tosa.minimum %[[VAL_12]], %[[VAL_11]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor}> : () -> tensor +// CHECK: %[[VAL_16:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_15]], %[[VAL_16]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_18:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_18]] : (tensor<1xf32>, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_17]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_21:.*]] = tosa.minimum %[[VAL_20]], %[[VAL_19]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: %[[VAL_23:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_23]] : (tensor<1xf32>, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_25:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_25]] : (tensor<1xf32>, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_27:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_24]] : 
(tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_28:.*]] = tosa.minimum %[[VAL_27]], %[[VAL_26]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_29:.*]] = torch_c.from_builtin_tensor %[[VAL_28]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: return %[[VAL_14]], %[[VAL_22]], %[[VAL_29]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32> // CHECK: } func.func @torch.aten.clamp.Tensor$basic(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) { %none = torch.constant.none @@ -2769,9 +3053,10 @@ func.func @torch.aten.clamp.Tensor$basic(%arg0: !torch.vtensor<[3,5],f32>, %arg1 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<2x3x4xf32>) -> tensor<2x12xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x12xf32> -> !torch.vtensor<[2,12],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,12],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[2, 12]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_4]] : (tensor<2x3x4xf32>, !tosa.shape<2>) -> tensor<2x12xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<2x12xf32> -> !torch.vtensor<[2,12],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[2,12],f32> // CHECK: } func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> { %int1 = torch.constant.int 1 @@ -2803,13 +3088,17 @@ func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !tor // CHECK: %[[VAL_2:.*]] = torch.constant.int 3 // 
CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_5:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x2x4xf32>) -> tensor<1x2x3xf32> -// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> -// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x2x4xf32>) -> tensor<1x2x1xf32> -// CHECK: %[[VAL_8:.*]] = tosa.reverse %[[VAL_7]] {axis = 2 : i32} : (tensor<1x2x1xf32>) -> tensor<1x2x1xf32> -// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_6]], %[[VAL_1]], %[[VAL_8]] {axis = 2 : i32} : (tensor<1x2x3xf32>, tensor<1x2x4xf32>, tensor<1x2x1xf32>) -> tensor<1x2x8xf32> -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x2x8xf32> -> !torch.vtensor<[1,2,8],f32> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[1,2,8],f32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[0, 0, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]], %[[VAL_5]], %[[VAL_6]] : (tensor<1x2x4xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_8:.*]] = tosa.reverse %[[VAL_7]] {axis = 2 : i32} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[0, 0, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_1]], %[[VAL_9]], %[[VAL_10]] : (tensor<1x2x4xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x2x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reverse %[[VAL_11]] {axis = 2 : i32} : (tensor<1x2x1xf32>) -> tensor<1x2x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_8]], 
%[[VAL_1]], %[[VAL_12]] {axis = 2 : i32} : (tensor<1x2x3xf32>, tensor<1x2x4xf32>, tensor<1x2x1xf32>) -> tensor<1x2x8xf32> +// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<1x2x8xf32> -> !torch.vtensor<[1,2,8],f32> +// CHECK: return %[[VAL_14]] : !torch.vtensor<[1,2,8],f32> // CHECK: } func.func @torch.aten.reflection_pad1d$basic(%arg0: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> { %int3 = torch.constant.int 3 @@ -2826,18 +3115,26 @@ func.func @torch.aten.reflection_pad1d$basic(%arg0: !torch.vtensor<[1,2,4],f32>) // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,20,20],f32> -> tensor<1x20x20xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 10 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32> -// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_4]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32> -// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[0, 0, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[1, 20, 10]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_1]], %[[VAL_4]], %[[VAL_5]] : (tensor<1x20x20xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x20x10xf32> // CHECK: %[[VAL_7:.*]] = tosa.reverse %[[VAL_6]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32> -// CHECK: %[[VAL_8:.*]] = tosa.concat %[[VAL_5]], %[[VAL_1]], %[[VAL_7]] {axis = 2 : i32} : (tensor<1x20x10xf32>, tensor<1x20x20xf32>, tensor<1x20x10xf32>) -> tensor<1x20x40xf32> -// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array, start 
= array} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32> -// CHECK: %[[VAL_10:.*]] = tosa.reverse %[[VAL_9]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> -// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_8]] {size = array, start = array} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32> -// CHECK: %[[VAL_12:.*]] = tosa.reverse %[[VAL_11]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> -// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_10]], %[[VAL_8]], %[[VAL_12]] {axis = 1 : i32} : (tensor<1x10x40xf32>, tensor<1x20x40xf32>, tensor<1x10x40xf32>) -> tensor<1x40x40xf32> -// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<1x40x40xf32> -> !torch.vtensor<[1,40,40],f32> -// CHECK: return %[[VAL_14]] : !torch.vtensor<[1,40,40],f32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[0, 0, 9]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[1, 20, 10]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_1]], %[[VAL_8]], %[[VAL_9]] : (tensor<1x20x20xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reverse %[[VAL_10]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_7]], %[[VAL_1]], %[[VAL_11]] {axis = 2 : i32} : (tensor<1x20x10xf32>, tensor<1x20x20xf32>, tensor<1x20x10xf32>) -> tensor<1x20x40xf32> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[0, 1, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[1, 10, 40]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_15:.*]] = tosa.slice %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] : (tensor<1x20x40xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reverse %[[VAL_15]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_17:.*]] = 
tosa.const_shape {value = dense<[0, 9, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_18:.*]] = tosa.const_shape {value = dense<[1, 10, 40]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_12]], %[[VAL_17]], %[[VAL_18]] : (tensor<1x20x40xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reverse %[[VAL_19]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_21:.*]] = tosa.concat %[[VAL_16]], %[[VAL_12]], %[[VAL_20]] {axis = 1 : i32} : (tensor<1x10x40xf32>, tensor<1x20x40xf32>, tensor<1x10x40xf32>) -> tensor<1x40x40xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x40x40xf32> -> !torch.vtensor<[1,40,40],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,40,40],f32> // CHECK: } func.func @torch.aten.reflection_pad2d$basic(%arg0: !torch.vtensor<[1,20,20],f32>) -> !torch.vtensor<[1,40,40],f32> { %int10 = torch.constant.int 10 @@ -2849,27 +3146,40 @@ func.func @torch.aten.reflection_pad2d$basic(%arg0: !torch.vtensor<[1,20,20],f32 // ----- // CHECK-LABEL: func.func @torch.aten.reflection_pad3d$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,5,7,3,4],f32>) -> !torch.vtensor<[4,5,11,7,8],f32> { -// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,5,7,3,4],f32> -> tensor<4x5x7x3x4xf32> -// CHECK: %[[VAL_1:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[SLICE_L:.*]] = tosa.slice %[[VAL_0]] {size = array, start = array} : (tensor<4x5x7x3x4xf32>) -> tensor<4x5x7x3x2xf32> -// CHECK: %[[REVERSE_L:.*]] = tosa.reverse %[[SLICE_L]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32> -// CHECK: %[[SLICE_R:.*]] = tosa.slice %[[VAL_0]] {size = array, start = 
array} : (tensor<4x5x7x3x4xf32>) -> tensor<4x5x7x3x2xf32> -// CHECK: %[[REVERSE_R:.*]] = tosa.reverse %[[SLICE_R]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32> -// CHECK: %[[CONCAT_LR:.*]] = tosa.concat %[[REVERSE_L]], %[[VAL_0]], %[[REVERSE_R]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>, tensor<4x5x7x3x4xf32>, tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x8xf32> -// CHECK: %[[SLICE_T:.*]] = tosa.slice %[[CONCAT_LR]] {size = array, start = array} : (tensor<4x5x7x3x8xf32>) -> tensor<4x5x7x2x8xf32> -// CHECK: %[[REVERSE_T:.*]] = tosa.reverse %[[SLICE_T]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32> -// CHECK: %[[SLICE_B:.*]] = tosa.slice %[[CONCAT_LR]] {size = array, start = array} : (tensor<4x5x7x3x8xf32>) -> tensor<4x5x7x2x8xf32> -// CHECK: %[[REVERSE_B:.*]] = tosa.reverse %[[SLICE_B]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32> -// CHECK: %[[CONCAT_TB:.*]] = tosa.concat %[[REVERSE_T]], %[[CONCAT_LR]], %[[REVERSE_B]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>, tensor<4x5x7x3x8xf32>, tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x7x8xf32> -// CHECK: %[[SLICE_F:.*]] = tosa.slice %[[CONCAT_TB]] {size = array, start = array} : (tensor<4x5x7x7x8xf32>) -> tensor<4x5x2x7x8xf32> -// CHECK: %[[REVERSE_F:.*]] = tosa.reverse %[[SLICE_F]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32> -// CHECK: %[[SLICE_BACK:.*]] = tosa.slice %[[CONCAT_TB]] {size = array, start = array} : (tensor<4x5x7x7x8xf32>) -> tensor<4x5x2x7x8xf32> -// CHECK: %[[REVERSE_BACK:.*]] = tosa.reverse %[[SLICE_BACK]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32> -// CHECK: %[[CONCAT_FB:.*]] = tosa.concat %[[REVERSE_F]], %[[CONCAT_TB]], %[[REVERSE_BACK]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>, tensor<4x5x7x7x8xf32>, tensor<4x5x2x7x8xf32>) -> tensor<4x5x11x7x8xf32> -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CONCAT_FB]] : tensor<4x5x11x7x8xf32> -> !torch.vtensor<[4,5,11,7,8],f32> -// CHECK: return 
%[[RESULT]] +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5,7,3,4],f32>) -> !torch.vtensor<[4,5,11,7,8],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,7,3,4],f32> -> tensor<4x5x7x3x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 1]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[4, 5, 7, 3, 2]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_1]], %[[VAL_4]], %[[VAL_5]] : (tensor<4x5x7x3x4xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reverse %[[VAL_6]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 1]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[4, 5, 7, 3, 2]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_1]], %[[VAL_8]], %[[VAL_9]] : (tensor<4x5x7x3x4xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reverse %[[VAL_10]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_7]], %[[VAL_1]], %[[VAL_11]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>, tensor<4x5x7x3x4xf32>, tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x8xf32> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[4, 5, 7, 2, 8]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_15:.*]] = tosa.slice 
%[[VAL_12]], %[[VAL_13]], %[[VAL_14]] : (tensor<4x5x7x3x8xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reverse %[[VAL_15]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<0> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_18:.*]] = tosa.const_shape {value = dense<[4, 5, 7, 2, 8]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_12]], %[[VAL_17]], %[[VAL_18]] : (tensor<4x5x7x3x8xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reverse %[[VAL_19]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[VAL_21:.*]] = tosa.concat %[[VAL_16]], %[[VAL_12]], %[[VAL_20]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>, tensor<4x5x7x3x8xf32>, tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x7x8xf32> +// CHECK: %[[VAL_22:.*]] = tosa.const_shape {value = dense<[0, 0, 1, 0, 0]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_23:.*]] = tosa.const_shape {value = dense<[4, 5, 2, 7, 8]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_24:.*]] = tosa.slice %[[VAL_21]], %[[VAL_22]], %[[VAL_23]] : (tensor<4x5x7x7x8xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[VAL_25:.*]] = tosa.reverse %[[VAL_24]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[VAL_26:.*]] = tosa.const_shape {value = dense<[0, 0, 4, 0, 0]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_27:.*]] = tosa.const_shape {value = dense<[4, 5, 2, 7, 8]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_28:.*]] = tosa.slice %[[VAL_21]], %[[VAL_26]], %[[VAL_27]] : (tensor<4x5x7x7x8xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[VAL_29:.*]] = tosa.reverse %[[VAL_28]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32> +// CHECK: 
%[[VAL_30:.*]] = tosa.concat %[[VAL_25]], %[[VAL_21]], %[[VAL_29]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>, tensor<4x5x7x7x8xf32>, tensor<4x5x2x7x8xf32>) -> tensor<4x5x11x7x8xf32> +// CHECK: %[[VAL_31:.*]] = torch_c.from_builtin_tensor %[[VAL_30]] : tensor<4x5x11x7x8xf32> -> !torch.vtensor<[4,5,11,7,8],f32> +// CHECK: return %[[VAL_31]] : !torch.vtensor<[4,5,11,7,8],f32> +// CHECK: } func.func @torch.aten.reflection_pad3d$basic(%arg0: !torch.vtensor<[4,5,7,3,4],f32>) -> !torch.vtensor<[4,5,11,7,8],f32> { %int2 = torch.constant.int 2 %0 = torch.prim.ListConstruct %int2, %int2, %int2, %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list @@ -2887,14 +3197,22 @@ func.func @torch.aten.reflection_pad3d$basic(%arg0: !torch.vtensor<[4,5,7,3,4],f // CHECK: %[[VAL_4:.*]] = torch.constant.int 3 // CHECK: %[[VAL_5:.*]] = torch.constant.int 4 // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32> -// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32> -// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_7]], %[[VAL_1]], %[[VAL_8]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x3x1xf32>, tensor<1x1x3x3xf32>, tensor<1x1x3x1xf32>, tensor<1x1x3x1xf32>) -> tensor<1x1x3x6xf32> -// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_9]] {size = array, start = array} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32> -// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_9]] {size = array, start = array} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32> -// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_9]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] {axis = 2 : i32} : (tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, 
tensor<1x1x1x6xf32>, tensor<1x1x3x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>) -> tensor<1x1x10x6xf32> -// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x1x10x6xf32> -> !torch.vtensor<[1,1,10,6],f32> -// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,1,10,6],f32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[1, 1, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_1]], %[[VAL_7]], %[[VAL_8]] : (tensor<1x1x3x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x1x3x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<[1, 1, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_12:.*]] = tosa.slice %[[VAL_1]], %[[VAL_10]], %[[VAL_11]] : (tensor<1x1x3x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x1x3x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_9]], %[[VAL_1]], %[[VAL_12]], %[[VAL_12]] {axis = 3 : i32} : (tensor<1x1x3x1xf32>, tensor<1x1x3x3xf32>, tensor<1x1x3x1xf32>, tensor<1x1x3x1xf32>) -> tensor<1x1x3x6xf32> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_13]], %[[VAL_14]], %[[VAL_15]] : (tensor<1x1x3x6xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x1x1x6xf32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[0, 0, 2, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_18:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 6]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_13]], %[[VAL_17]], %[[VAL_18]] : 
(tensor<1x1x3x6xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x1x1x6xf32> +// CHECK: %[[VAL_20:.*]] = tosa.concat %[[VAL_16]], %[[VAL_16]], %[[VAL_16]], %[[VAL_13]], %[[VAL_19]], %[[VAL_19]], %[[VAL_19]], %[[VAL_19]] {axis = 2 : i32} : (tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x3x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>) -> tensor<1x1x10x6xf32> +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<1x1x10x6xf32> -> !torch.vtensor<[1,1,10,6],f32> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[1,1,10,6],f32> // CHECK: } func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,10,6],f32> { %int1 = torch.constant.int 1 @@ -2913,15 +3231,18 @@ func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f3 // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],f32> -> tensor<4xf32> // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3],f32> -> tensor<3xf32> -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3xf32>) -> tensor<3x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK: %[[VAL_6:.*]] = tosa.tile %[[VAL_4]], %[[VAL_5]] : (tensor<3x1xf32>, !tosa.shape<2>) -> tensor<3x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<1x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[3, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK: %[[VAL_9:.*]] = tosa.tile %[[VAL_7]], %[[VAL_8]] : (tensor<1x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32> -// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_6]], %[[VAL_9]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: 
%[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[3, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor<3xf32>, !tosa.shape<2>) -> tensor<3x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_5]], %[[VAL_6]] : (tensor<3x1xf32>, !tosa.shape<2>) -> tensor<3x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_8]] : (tensor<4xf32>, !tosa.shape<2>) -> tensor<1x4xf32> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[3, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_11:.*]] = tosa.tile %[[VAL_9]], %[[VAL_10]] : (tensor<1x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_7]], %[[VAL_11]], %[[VAL_12]] : (tensor<3x4xf32>, tensor<3x4xf32>, tensor<1xi8>) -> tensor<3x4xf32> +// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_14]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.outer$basic(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.outer %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[3,4],f32> @@ -2935,10 +3256,12 @@ func.func @torch.aten.outer$basic(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch. 
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,8,3,3],si64> -> tensor<1x8x3x3xi64> // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x8x3x3xi64>) -> tensor<1x2x4x3x3xi64> -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x2x4x3x3xi64>) -> tensor<1x2x2x2x3x3xi64> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x2x2x3x3xi64> -> !torch.vtensor<[1,2,2,2,3,3],si64> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,2,2,3,3],si64> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 2, 4, 3, 3]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_4]] : (tensor<1x8x3x3xi64>, !tosa.shape<5>) -> tensor<1x2x4x3x3xi64> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 2, 2, 2, 3, 3]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor<1x2x4x3x3xi64>, !tosa.shape<6>) -> tensor<1x2x2x2x3x3xi64> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x2x2x2x3x3xi64> -> !torch.vtensor<[1,2,2,2,3,3],si64> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,2,2,2,3,3],si64> // CHECK: } func.func @torch.prims.split_dim$basic(%arg0: !torch.vtensor<[1,8,3,3],si64>) -> !torch.vtensor<[1,2,2,2,3,3],si64> { %int1 = torch.constant.int 1 @@ -2958,24 +3281,33 @@ func.func @torch.prims.split_dim$basic(%arg0: !torch.vtensor<[1,8,3,3],si64>) -> // CHECK: %[[VAL_4:.*]] = torch.constant.int 8 // CHECK: %[[VAL_5:.*]] = torch.constant.int 9 // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x1x2x3xf64>) -> tensor<1x1x6xf64> -// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = 
dense<{{\[\[}}[0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5]]]> : tensor<1x1x72xi32>}> : () -> tensor<1x1x72xi32> -// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<1x1x72xi32>) -> tensor<1x1x72x1xi32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32> -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32> -// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_11]], %[[VAL_9]] {axis = 3 : i32} : (tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>) -> tensor<1x1x72x3xi32> -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x1x6xf64>) -> tensor<1x6x1xf64> -// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<1x1x72x3xi32>) -> tensor<72x3xi32> -// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[6, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> -// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]] {shift = 0 : i8} : (tensor<72x3xi32>, tensor<1x3xi32>) -> tensor<72x3xi32> -// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<72x3xi32>) -> tensor<72x1xi32> -// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<72x1xi32>) -> tensor<1x72xi32> -// CHECK: %[[VAL_20:.*]] = tosa.gather %[[VAL_13]], %[[VAL_19]] : (tensor<1x6x1xf64>, tensor<1x72xi32>) -> tensor<1x72x1xf64> -// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x72x1xf64>) -> tensor<1x1x72xf64> -// CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array} : (tensor<1x1x72xf64>) -> 
tensor<1x1x8x9xf64> -// CHECK: %[[VAL_23:.*]] = torch_c.from_builtin_tensor %[[VAL_22]] : tensor<1x1x8x9xf64> -> !torch.vtensor<[1,1,8,9],f64> -// CHECK: return %[[VAL_23]] : !torch.vtensor<[1,1,8,9],f64> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<[1, 1, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_7]] : (tensor<1x1x2x3xf64>, !tosa.shape<3>) -> tensor<1x1x6xf64> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5]]]> : tensor<1x1x72xi32>}> : () -> tensor<1x1x72xi32> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 1, 72, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_10]] : (tensor<1x1x72xi32>, !tosa.shape<4>) -> tensor<1x1x72x1xi32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32> +// CHECK: %[[VAL_14:.*]] = tosa.concat %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] {axis = 3 : i32} : (tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>) -> tensor<1x1x72x3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.const_shape {value = dense<[1, 6, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_15]] : (tensor<1x1x6xf64>, !tosa.shape<3>) -> tensor<1x6x1xf64> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[72, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_14]], %[[VAL_17]] : (tensor<1x1x72x3xi32>, !tosa.shape<2>) -> tensor<72x3xi32> +// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<[6, 6, 1]> : 
tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_20:.*]] = tosa.const_shape {value = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_19]], %[[VAL_20]] : (tensor<3xi32>, !tosa.shape<2>) -> tensor<1x3xi32> +// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_18]], %[[VAL_21]], %[[VAL_22]] : (tensor<72x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<72x3xi32> +// CHECK: %[[VAL_24:.*]] = tosa.reduce_sum %[[VAL_23]] {axis = 1 : i32} : (tensor<72x3xi32>) -> tensor<72x1xi32> +// CHECK: %[[VAL_25:.*]] = tosa.const_shape {value = dense<[1, 72]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_24]], %[[VAL_25]] : (tensor<72x1xi32>, !tosa.shape<2>) -> tensor<1x72xi32> +// CHECK: %[[VAL_27:.*]] = tosa.gather %[[VAL_16]], %[[VAL_26]] : (tensor<1x6x1xf64>, tensor<1x72xi32>) -> tensor<1x72x1xf64> +// CHECK: %[[VAL_28:.*]] = tosa.const_shape {value = dense<[1, 1, 72]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_29:.*]] = tosa.reshape %[[VAL_27]], %[[VAL_28]] : (tensor<1x72x1xf64>, !tosa.shape<3>) -> tensor<1x1x72xf64> +// CHECK: %[[VAL_30:.*]] = tosa.const_shape {value = dense<[1, 1, 8, 9]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_31:.*]] = tosa.reshape %[[VAL_29]], %[[VAL_30]] : (tensor<1x1x72xf64>, !tosa.shape<4>) -> tensor<1x1x8x9xf64> +// CHECK: %[[VAL_32:.*]] = torch_c.from_builtin_tensor %[[VAL_31]] : tensor<1x1x8x9xf64> -> !torch.vtensor<[1,1,8,9],f64> +// CHECK: return %[[VAL_32]] : !torch.vtensor<[1,1,8,9],f64> // CHECK: } func.func @torch.aten.upsample_nearest2d$basic(%arg0: !torch.vtensor<[1,1,2,3],f64>) -> !torch.vtensor<[1,1,8,9],f64> { %float4.000000e00 = torch.constant.float 4.000000e+00 @@ -2996,24 +3328,33 @@ func.func @torch.aten.upsample_nearest2d$basic(%arg0: !torch.vtensor<[1,1,2,3],f // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 // 
CHECK: %[[VAL_4:.*]] = torch.constant.int 7 // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x1x4x5xf32>) -> tensor<1x1x20xf32> -// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 1, 2, 2, 3, 4, 10, 10, 11, 12, 12, 13, 14]]]> : tensor<1x1x14xi32>}> : () -> tensor<1x1x14xi32> -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x1x14xi32>) -> tensor<1x1x14x1xi32> -// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32> -// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>) -> tensor<1x1x14x3xi32> -// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x1x20xf32>) -> tensor<1x20x1xf32> -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<1x1x14x3xi32>) -> tensor<14x3xi32> -// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[20, 20, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<14x3xi32>, tensor<1x3xi32>) -> tensor<14x3xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<14x3xi32>) -> tensor<14x1xi32> -// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<14x1xi32>) -> tensor<1x14xi32> -// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_12]], %[[VAL_18]] : (tensor<1x20x1xf32>, tensor<1x14xi32>) -> tensor<1x14x1xf32> -// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = 
array} : (tensor<1x14x1xf32>) -> tensor<1x1x14xf32> -// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x1x14xf32>) -> tensor<1x1x2x7xf32> -// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x1x2x7xf32> -> !torch.vtensor<[1,1,2,7],f32> -// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,1,2,7],f32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 1, 20]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_6]] : (tensor<1x1x4x5xf32>, !tosa.shape<3>) -> tensor<1x1x20xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 1, 2, 2, 3, 4, 10, 10, 11, 12, 12, 13, 14]]]> : tensor<1x1x14xi32>}> : () -> tensor<1x1x14xi32> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[1, 1, 14, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_9]] : (tensor<1x1x14xi32>, !tosa.shape<4>) -> tensor<1x1x14x1xi32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32> +// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_11]], %[[VAL_12]], %[[VAL_10]] {axis = 3 : i32} : (tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>) -> tensor<1x1x14x3xi32> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[1, 20, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_14]] : (tensor<1x1x20xf32>, !tosa.shape<3>) -> tensor<1x20x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.const_shape {value = dense<[14, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_13]], %[[VAL_16]] : (tensor<1x1x14x3xi32>, !tosa.shape<2>) -> tensor<14x3xi32> +// CHECK: %[[VAL_18:.*]] = "tosa.const"() <{value = dense<[20, 20, 1]> : tensor<3xi32>}> : () 
-> tensor<3xi32> +// CHECK: %[[VAL_19:.*]] = tosa.const_shape {value = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<3xi32>, !tosa.shape<2>) -> tensor<1x3xi32> +// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_17]], %[[VAL_20]], %[[VAL_21]] : (tensor<14x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<14x3xi32> +// CHECK: %[[VAL_23:.*]] = tosa.reduce_sum %[[VAL_22]] {axis = 1 : i32} : (tensor<14x3xi32>) -> tensor<14x1xi32> +// CHECK: %[[VAL_24:.*]] = tosa.const_shape {value = dense<[1, 14]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_23]], %[[VAL_24]] : (tensor<14x1xi32>, !tosa.shape<2>) -> tensor<1x14xi32> +// CHECK: %[[VAL_26:.*]] = tosa.gather %[[VAL_15]], %[[VAL_25]] : (tensor<1x20x1xf32>, tensor<1x14xi32>) -> tensor<1x14x1xf32> +// CHECK: %[[VAL_27:.*]] = tosa.const_shape {value = dense<[1, 1, 14]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_28:.*]] = tosa.reshape %[[VAL_26]], %[[VAL_27]] : (tensor<1x14x1xf32>, !tosa.shape<3>) -> tensor<1x1x14xf32> +// CHECK: %[[VAL_29:.*]] = tosa.const_shape {value = dense<[1, 1, 2, 7]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_30:.*]] = tosa.reshape %[[VAL_28]], %[[VAL_29]] : (tensor<1x1x14xf32>, !tosa.shape<4>) -> tensor<1x1x2x7xf32> +// CHECK: %[[VAL_31:.*]] = torch_c.from_builtin_tensor %[[VAL_30]] : tensor<1x1x2x7xf32> -> !torch.vtensor<[1,1,2,7],f32> +// CHECK: return %[[VAL_31]] : !torch.vtensor<[1,1,2,7],f32> // CHECK: } func.func @torch.aten.upsample_nearest2d.vec$basic(%arg0: !torch.vtensor<[1,1,4,5],f32>) -> !torch.vtensor<[1,1,2,7],f32> { %none = torch.constant.none @@ -3035,17 +3376,21 @@ func.func @torch.aten.upsample_nearest2d.vec$basic(%arg0: !torch.vtensor<[1,1,4, // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<5x3xf32>}> 
: () -> tensor<5x3xf32> // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<4.471500e-02> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.636619746> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_3]], %[[VAL_1]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_9:.*]] = tosa.pow %[[VAL_7]], %[[VAL_3]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_10:.*]] = tosa.pow %[[VAL_1]], %[[VAL_5]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_6]], %[[VAL_10]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_12:.*]] = tosa.add %[[VAL_1]], %[[VAL_11]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_9]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_14:.*]] = tosa.tanh %[[VAL_13]] : (tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_4]], %[[VAL_14]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_8]], %[[VAL_15]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<5x3xf32> -> !torch.vtensor<[5,3],f32> -// CHECK: return %[[VAL_17]] : !torch.vtensor<[5,3],f32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_3]], %[[VAL_1]], %[[VAL_8]] : (tensor<5x3xf32>, tensor<5x3xf32>, tensor<1xi8>) -> tensor<5x3xf32> +// CHECK: %[[VAL_10:.*]] = tosa.pow %[[VAL_7]], %[[VAL_3]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_11:.*]] = tosa.pow %[[VAL_1]], %[[VAL_5]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: 
%[[VAL_12:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_6]], %[[VAL_11]], %[[VAL_12]] : (tensor<5x3xf32>, tensor<5x3xf32>, tensor<1xi8>) -> tensor<5x3xf32> +// CHECK: %[[VAL_14:.*]] = tosa.add %[[VAL_1]], %[[VAL_13]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_10]], %[[VAL_14]], %[[VAL_15]] : (tensor<5x3xf32>, tensor<5x3xf32>, tensor<1xi8>) -> tensor<5x3xf32> +// CHECK: %[[VAL_17:.*]] = tosa.tanh %[[VAL_16]] : (tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_18:.*]] = tosa.add %[[VAL_4]], %[[VAL_17]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_20:.*]] = tosa.mul %[[VAL_9]], %[[VAL_18]], %[[VAL_19]] : (tensor<5x3xf32>, tensor<5x3xf32>, tensor<1xi8>) -> tensor<5x3xf32> +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<5x3xf32> -> !torch.vtensor<[5,3],f32> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[5,3],f32> // CHECK: } func.func @torch.aten.gelu$tanh(%arg0: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[5,3],f32> { %str = torch.constant.str "tanh" @@ -3074,13 +3419,15 @@ func.func @torch.aten.exp$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtens // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+01> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_3]] : 
(tensor<1x1xf32>) -> tensor<1x1xf32> -// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5]] : (tensor<1x1xf32>) -> tensor<1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_6]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_3:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_3]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.log %[[VAL_4]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reciprocal %[[VAL_6]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_5]], %[[VAL_7]], %[[VAL_8]] : (tensor<3x4xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<3x4xf32> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.log10$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.log10 %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> @@ -3094,13 +3441,15 @@ func.func @torch.aten.log10$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vt // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> // CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+01> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) 
-> tensor<1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.log %[[VAL_4]] : (tensor<1x1xf32>) -> tensor<1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.reciprocal %[[VAL_6]] : (tensor<1x1xf32>) -> tensor<1x1xf32> -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.log %[[VAL_5]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_6]], %[[VAL_8]], %[[VAL_9]] : (tensor<3x4xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<3x4xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.log10$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.log10 %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> @@ -3113,11 +3462,12 @@ func.func @torch.aten.log10$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vte // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> // CHECK: 
%[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_4:.*]] = tosa.add %[[VAL_1]], %[[VAL_3]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_3:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_3]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_1]], %[[VAL_4]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.log %[[VAL_5]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.log1p$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.log1p %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> @@ -3131,11 +3481,12 @@ func.func @torch.aten.log1p$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vt // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> // CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_2]], %[[VAL_4]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.log 
%[[VAL_5]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_2]], %[[VAL_5]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.log %[[VAL_6]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.log1p$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.log1p %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> @@ -3150,13 +3501,15 @@ func.func @torch.aten.log1p$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vte // CHECK: %[[VAL_2:.*]] = torch.constant.float 9.9999999999999995E-8 // CHECK: %[[VAL_3:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0.99999988 : f32, max_int = 0 : i64, min_fp = 1.000000e-07 : f32, min_int = 0 : i64} : (tensor<3x4xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_6:.*]] = tosa.sub %[[VAL_5]], %[[VAL_3]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.reciprocal %[[VAL_6]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_3]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_9:.*]] = tosa.log %[[VAL_8]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: 
%[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_5]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_6]], %[[VAL_3]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_3]], %[[VAL_8]], %[[VAL_9]] : (tensor<3x4xf32>, tensor<3x4xf32>, tensor<1xi8>) -> tensor<3x4xf32> +// CHECK: %[[VAL_11:.*]] = tosa.log %[[VAL_10]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_12]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.logit$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { %float9.999990e-08 = torch.constant.float 9.9999999999999995E-8 @@ -3173,13 +3526,15 @@ func.func @torch.aten.logit$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vt // CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> // CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_3]] {max_fp = 0.99999988 : f32, max_int = 0 : i64, min_fp = 1.000000e-07 : f32, min_int = 0 : i64} : (tensor<3x4xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_6]], %[[VAL_4]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.reciprocal %[[VAL_7]] 
: (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_4]], %[[VAL_8]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_10:.*]] = tosa.log %[[VAL_9]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_7]], %[[VAL_4]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_9:.*]] = tosa.reciprocal %[[VAL_8]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_4]], %[[VAL_9]], %[[VAL_10]] : (tensor<3x4xf32>, tensor<3x4xf32>, tensor<1xi8>) -> tensor<3x4xf32> +// CHECK: %[[VAL_12:.*]] = tosa.log %[[VAL_11]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_13]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.logit$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { %float9.999990e-08 = torch.constant.float 9.9999999999999995E-8 @@ -3211,9 +3566,10 @@ func.func @torch.aten.log$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtens // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.693147182> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> // CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<1x1xf32>) -> tensor<1x1xf32> // CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]] {shift = 0 : 
i8} : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]], %[[VAL_6]] : (tensor<3x4xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<3x4xf32> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.log2$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.log2 %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> @@ -3240,14 +3596,13 @@ func.func @torch.aten.erf$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtens // CHECK-LABEL: func.func @torch.aten.lt.Scalar$intfloat( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4],si64> -> tensor<4xi64> -// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.100000e+00 -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.100000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.1000000238418579> : tensor}> : () -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1xf64> -// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_1]] : (tensor<4xi64>) -> tensor<4xf64> -// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_5]], %[[VAL_6]] : (tensor<1xf64>, tensor<4xf64>) -> tensor<4xi1> -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<4xi1> -> !torch.vtensor<[4],i1> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[4],i1> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.1000000238418579> : tensor}> : () 
-> tensor +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_5]] : (tensor, !tosa.shape<1>) -> tensor<1xf64> +// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_1]] : (tensor<4xi64>) -> tensor<4xf64> +// CHECK: %[[VAL_8:.*]] = tosa.greater %[[VAL_6]], %[[VAL_7]] : (tensor<1xf64>, tensor<4xf64>) -> tensor<4xi1> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<4xi1> -> !torch.vtensor<[4],i1> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[4],i1> // CHECK: } func.func @torch.aten.lt.Scalar$intfloat(%arg0: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> { %float1.100000e00 = torch.constant.float 1.100000e+00 @@ -3278,9 +3633,10 @@ func.func @torch.aten.sigmoid$int(%arg0: !torch.vtensor<[3,5],si32>) -> !torch.v // CHECK: %[[VAL_2:.*]] = tosa.sin %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_3:.*]] = tosa.cos %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]], %[[VAL_5]] : (tensor<3x4xf32>, tensor<3x4xf32>, tensor<1xi8>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.tan$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.tan %arg0 : !torch.vtensor<[3,4],f32> -> 
!torch.vtensor<[3,4],f32> @@ -3296,9 +3652,10 @@ func.func @torch.aten.tan$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vten // CHECK: %[[VAL_3:.*]] = tosa.sin %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_4:.*]] = tosa.cos %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]], %[[VAL_6]] : (tensor<3x4xf32>, tensor<3x4xf32>, tensor<1xi8>) -> tensor<3x4xf32> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.tan$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.tan %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> @@ -3345,23 +3702,31 @@ func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],s // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]]> : tensor<6x4xi32>}> : () -> tensor<6x4xi32> -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<6x4xi32>) -> tensor<6x4x1xi32> -// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], 
{{\[\[}}0], [1], [2], [3]]]> : tensor<6x4x1xi32>}> : () -> tensor<6x4x1xi32> -// CHECK: %[[VAL_7:.*]] = tosa.concat %[[VAL_5]], %[[VAL_6]] {axis = 2 : i32} : (tensor<6x4x1xi32>, tensor<6x4x1xi32>) -> tensor<6x4x2xi32> -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<1x24x1xf32> -// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<6x4x2xi32>) -> tensor<24x2xi32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> -// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x2xi32> -// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_9]], %[[VAL_11]] {shift = 0 : i8} : (tensor<24x2xi32>, tensor<1x2xi32>) -> tensor<24x2xi32> -// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 1 : i32} : (tensor<24x2xi32>) -> tensor<24x1xi32> -// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> -// CHECK: %[[VAL_15:.*]] = tosa.gather %[[VAL_8]], %[[VAL_14]] : (tensor<1x24x1xf32>, tensor<1x24xi32>) -> tensor<1x24x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<1x24x1xf32>) -> tensor<6x4xf32> -// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<3x2x4xf32> -// CHECK: %[[VAL_18:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_19:.*]] = tosa.transpose %[[VAL_17]], %[[VAL_18]] : (tensor<3x2x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32> -// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32> -// CHECK: return %[[VAL_20]] : !torch.vtensor<[3,4,2],f32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[6, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_5]] : (tensor<6x4xi32>, 
!tosa.shape<3>) -> tensor<6x4x1xi32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]> : tensor<6x4x1xi32>}> : () -> tensor<6x4x1xi32> +// CHECK: %[[VAL_8:.*]] = tosa.concat %[[VAL_6]], %[[VAL_7]] {axis = 2 : i32} : (tensor<6x4x1xi32>, tensor<6x4x1xi32>) -> tensor<6x4x2xi32> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[1, 24, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<6x4xf32>, !tosa.shape<3>) -> tensor<1x24x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.const_shape {value = dense<[24, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_11]] : (tensor<6x4x2xi32>, !tosa.shape<2>) -> tensor<24x2xi32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[1, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_13]], %[[VAL_14]] : (tensor<2xi32>, !tosa.shape<2>) -> tensor<1x2xi32> +// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_12]], %[[VAL_15]], %[[VAL_16]] : (tensor<24x2xi32>, tensor<1x2xi32>, tensor<1xi8>) -> tensor<24x2xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<24x2xi32>) -> tensor<24x1xi32> +// CHECK: %[[VAL_19:.*]] = tosa.const_shape {value = dense<[1, 24]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<24x1xi32>, !tosa.shape<2>) -> tensor<1x24xi32> +// CHECK: %[[VAL_21:.*]] = tosa.gather %[[VAL_10]], %[[VAL_20]] : (tensor<1x24x1xf32>, tensor<1x24xi32>) -> tensor<1x24x1xf32> +// CHECK: 
%[[VAL_22:.*]] = tosa.const_shape {value = dense<[6, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_23:.*]] = tosa.reshape %[[VAL_21]], %[[VAL_22]] : (tensor<1x24x1xf32>, !tosa.shape<2>) -> tensor<6x4xf32> +// CHECK: %[[VAL_24:.*]] = tosa.const_shape {value = dense<[3, 2, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_23]], %[[VAL_24]] : (tensor<6x4xf32>, !tosa.shape<3>) -> tensor<3x2x4xf32> +// CHECK: %[[VAL_26:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_27:.*]] = tosa.transpose %[[VAL_25]], %[[VAL_26]] : (tensor<3x2x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32> +// CHECK: %[[VAL_28:.*]] = torch_c.from_builtin_tensor %[[VAL_27]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32> +// CHECK: return %[[VAL_28]] : !torch.vtensor<[3,4,2],f32> // CHECK: } func.func @torch.aten.unfold$basic(%arg0: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> { %int0 = torch.constant.int 0 @@ -3377,9 +3742,10 @@ func.func @torch.aten.unfold$basic(%arg0: !torch.vtensor<[6,4],f32>) -> !torch.v // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor) -> tensor<1xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1xf32> -> !torch.vtensor<[1],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[1],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_4]] : (tensor, !tosa.shape<1>) -> tensor<1xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32> // CHECK: } func.func 
@torch.aten.unfold$rank_zero(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { %int0 = torch.constant.int 0 @@ -3394,11 +3760,12 @@ func.func @torch.aten.unfold$rank_zero(%arg0: !torch.vtensor<[],f32>) -> !torch. // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_4:.*]] = tosa.exp %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_4]], %[[VAL_3]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_3:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_3]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.exp %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.sub %[[VAL_5]], %[[VAL_4]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.expm1$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.expm1 %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> @@ -3412,11 +3779,12 @@ func.func @torch.aten.expm1$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vt // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> // CHECK: 
%[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.exp %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.sub %[[VAL_5]], %[[VAL_4]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_4]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.exp %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_6]], %[[VAL_5]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.expm1$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.expm1 %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> @@ -3465,12 +3833,14 @@ func.func @torch.aten.constant_pad_nd$basic(%arg0: !torch.vtensor<[1,1,20,20,4,4 // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_12]] : (tensor<5x2x10x20xf32>, tensor<4xi32>) -> tensor<5x10x20x2xf32> // CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_4]], %[[VAL_12]] : (tensor<10x2x3x3xf32>, tensor<4xi32>) -> tensor<10x3x3x2xf32> -// CHECK: %[[VAL_15:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_14]], %[[VAL_11]] {acc_type = f32, 
dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>) -> tensor<5x14x24x10xf32> -// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_15]], %[[VAL_16]] : (tensor<5x14x24x10xf32>, tensor<4xi32>) -> tensor<5x10x14x24xf32> -// CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<5x10x14x24xf32> to tensor<5x10x14x24xf32> -// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32> -// CHECK: return %[[VAL_19]] : !torch.vtensor<[5,10,14,24],f32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_14]], %[[VAL_11]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x14x24x10xf32> +// CHECK: %[[VAL_18:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_19:.*]] = tosa.transpose %[[VAL_17]], %[[VAL_18]] : (tensor<5x14x24x10xf32>, tensor<4xi32>) -> tensor<5x10x14x24xf32> +// CHECK: %[[VAL_20:.*]] = tensor.cast %[[VAL_19]] : tensor<5x10x14x24xf32> to tensor<5x10x14x24xf32> +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[5,10,14,24],f32> // CHECK: } func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { %false = torch.constant.bool false @@ -3506,13 +3876,16 @@ func.func @torch.aten.convolution$basic(%arg0: 
!torch.vtensor<[5,2,10,20],f32>) // CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_13]] : (tensor<5x4x10x20xf32>, tensor<4xi32>) -> tensor<5x10x20x4xf32> // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[2, 3, 0, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_5]], %[[VAL_15]] : (tensor<4x1x3x3xf32>, tensor<4xi32>) -> tensor<3x3x4x1xf32> -// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<3x3x4x1xf32>) -> tensor<3x3x4x1xf32> -// CHECK: %[[VAL_18:.*]] = tosa.depthwise_conv2d %[[VAL_14]], %[[VAL_17]], %[[VAL_12]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>) -> tensor<5x5x10x4xf32> -// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_20:.*]] = tosa.transpose %[[VAL_18]], %[[VAL_19]] : (tensor<5x5x10x4xf32>, tensor<4xi32>) -> tensor<5x4x5x10xf32> -// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<5x4x5x10xf32> to tensor<5x4x5x10xf32> -// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32> -// CHECK: return %[[VAL_22]] : !torch.vtensor<[5,4,5,10],f32> +// CHECK: %[[VAL_17:.*]] = tosa.const_shape {value = dense<[3, 3, 4, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_16]], %[[VAL_17]] : (tensor<3x3x4x1xf32>, !tosa.shape<4>) -> tensor<3x3x4x1xf32> +// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.depthwise_conv2d %[[VAL_14]], %[[VAL_18]], %[[VAL_12]], %[[VAL_19]], %[[VAL_20]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, 
tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x5x10x4xf32> +// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_21]], %[[VAL_22]] : (tensor<5x5x10x4xf32>, tensor<4xi32>) -> tensor<5x4x5x10xf32> +// CHECK: %[[VAL_24:.*]] = tensor.cast %[[VAL_23]] : tensor<5x4x5x10xf32> to tensor<5x4x5x10xf32> +// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32> +// CHECK: return %[[VAL_25]] : !torch.vtensor<[5,4,5,10],f32> // CHECK: } func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> { %false = torch.constant.bool false diff --git a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir index 68ca4f0e08ac..42ff78d4dea0 100644 --- a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir +++ b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir @@ -2,9 +2,10 @@ // CHECK-LABEL: func.func @torch.aten.mul.Scalar$mixed_type( // CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16>) -> tensor<5xbf16> { -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16> -// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i8} : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16> -// CHECK: return %[[VAL_2]] : tensor<5xbf16> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16> +// CHECK: %[[VAL_3:.*]] = tosa.mul %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : (tensor<5xbf16>, tensor<1xbf16>, tensor<1xi8>) -> tensor<5xbf16> +// CHECK: return %[[VAL_3]] : tensor<5xbf16> // CHECK: } func.func 
@torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],bf16> { %float2.000000e00 = torch.constant.float 2.000000e+00 @@ -93,10 +94,11 @@ func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?], // CHECK-LABEL: func.func @torch.aten.div.Tensor$mixed_type_fp( // CHECK-SAME: %[[VAL_0:.*]]: tensor, // CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: return %[[VAL_4]] : tensor +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_1]] : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_0]], %[[VAL_4]], %[[VAL_2]] : (tensor, tensor, tensor<1xi8>) -> tensor +// CHECK: return %[[VAL_5]] : tensor // CHECK: } func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],f32> { %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],f32> @@ -119,9 +121,10 @@ func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si1 // CHECK-LABEL: torch.aten.div.Scalar$int_input_fp_output // CHECK-SAME: %[[VAL_0:.*]]: tensor +// CHECK: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<7.812500e-03> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> // CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_0]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_1]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], 
%[[VAL_1]], %[[SHIFT]] : (tensor, tensor<1x1xf32>, tensor<1xi8>) -> tensor func.func @torch.aten.div.Scalar$int_input_fp_output(%arg0: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],f32> { %int128 = torch.constant.int 128 %0 = torch.aten.div.Scalar %arg0, %int128 : !torch.vtensor<[?, ?],si64>, !torch.int -> !torch.vtensor<[?, ?],f32> diff --git a/test/python/fx_importer/v2.3/mutation_import.py b/test/python/fx_importer/v2.3/mutation_import.py index c2e5d9f14e2f..c0214b761467 100644 --- a/test/python/fx_importer/v2.3/mutation_import.py +++ b/test/python/fx_importer/v2.3/mutation_import.py @@ -171,26 +171,3 @@ def forward(self, x): ) print(m) m.operation.verify() - - -@run -# CHECK-LABEL: test_mutable_buffer_not_supported_without_hooks -# CHECK: EXPECTED ERROR: Store of a mutation to {{.*}} is not supported -def test_mutable_buffer_not_supported_without_hooks(): - class Basic(nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("buffer", torch.randn(3, 4)) - - def forward(self, x): - self.buffer.mul_(x) - return x - - try: - m = fx.export_and_import( - Basic(), - torch.randn(3, 4), - experimental_support_mutation=True, - ) - except NotImplementedError as e: - print("EXPECTED ERROR:", str(e))