diff --git a/externals/llvm-project b/externals/llvm-project index 951cb07c781f..41d02533ef16 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 951cb07c781f01533979916035e2ee1c061774af +Subproject commit 41d02533ef16c5671972000ac69053f5305199bd diff --git a/externals/stablehlo b/externals/stablehlo index d40285ef3db0..6e403b1aa6a7 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit d40285ef3db0687e3f1e2bb0d716d748485a9739 +Subproject commit 6e403b1aa6a71f5eaa09cc720e4ad42f692745e6 diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 4cc878350776..a429c405d24c 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -309,6 +309,61 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [ }]; } +def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$noise, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + +def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$noise, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenCeluOp : Torch_Op<"aten.celu", [ AllowsTypeRefinement, HasValueSemantics, @@ -7352,6 +7407,7 @@ def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", printDefaultTorchOp(printer, *this, 6, 2); } }]; + let hasCanonicalizer = 1; } def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_indices_backward", [ @@ -8079,6 +8135,7 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [ @@ -9671,6 +9728,7 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [ @@ -9695,6 +9753,7 @@ def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -14085,6 +14144,59 @@ def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [ }]; } +def Torch_AtenUpsampleBilinear2dOp : Torch_Op<"aten.upsample_bilinear2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size, + Torch_BoolType:$align_corners, + AnyTorchOptionalFloatType:$scales_h, + AnyTorchOptionalFloatType:$scales_w + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUpsampleBilinear2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenUpsampleBilinear2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenUpsampleBilinear2dVecOp : Torch_Op<"aten.upsample_bilinear2d.vec", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalListOfTorchIntType:$output_size, + Torch_BoolType:$align_corners, + AnyTorchOptionalListOfTorchFloatType:$scale_factors + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUpsampleBilinear2dVecOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenUpsampleBilinear2dVecOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [ AllowsTypeRefinement, HasValueSemantics, @@ -16861,6 +16973,35 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [ }]; } +def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchTensorType:$noise, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + Torch_BoolType:$self_is_result + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 469026cab908..82ffd584820a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1087,9 +1087,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return rewriter.notifyMatchFailure(binder.op, "auto_pad bind failure"); - if (autoPad != "NOTSET") - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: auto_pad != NOTSET"); Torch::ValueTensorType resultTypeOut; Value operand; @@ -1136,6 +1133,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return rewriter.notifyMatchFailure(binder.op, "dilations bind failure"); + // set default padding if (padding.empty()) padding.resize(spatial, 0); if (strides.empty()) @@ -1143,6 +1141,34 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (dilations.empty()) dilations.resize(spatial, 1); + auto inputTensorType = cast(operand.getType()); + + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + if (autoPad != "NOTSET" && autoPad != "VALID") { + const bool isSameLower = autoPad == "SAME_LOWER"; + ArrayRef inputShape = inputTensorType.getSizes(); + padding.resize_for_overwrite(2 * spatial); + for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) { + const int64_t dilatedKernelSize = + dilations[dimIdx] * (kernel[dimIdx] - 1) + 1; + int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / + strides[dimIdx] - + 1) * + strides[dimIdx] + + dilatedKernelSize - inputShape[dimIdx + 2]; + totalPad = totalPad >= 0 ? totalPad : 0; + padding[dimIdx] = + isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2); + padding[spatial + dimIdx] = totalPad - padding[dimIdx]; + } + } + // If the padding is symmetric we can push the padding operation to the // torch operator. if (padding.size() == static_cast(2 * spatial)) { diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index a4962d12abdc..9c914690bbf4 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1125,54 +1125,57 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { } if (numGroups == 1 && inputZp) { - // The quantized version uses a different channel ordering so we need to - // permute the tensors in order to use the existing path. We should - // eventually directly support this channel ordering. - llvm::SmallVector inPerms, weightPerms; - inPerms.push_back(0); // N stays at the front for input. - // Then we expect the spatial dimensions - for (size_t i = 0; i < numSpatialDims; ++i) { - inPerms.push_back(i + 2); - weightPerms.push_back(i + 2); - } - inPerms.push_back(1); - weightPerms.append({1, 0}); - - paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter); - weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter); - outputTensor = - transposeValue(op.getLoc(), outputTensor, inPerms, rewriter); - switch (numSpatialDims) { case 2: conv = rewriter - .create( + .create( loc, outputTensor.getType(), ValueRange{paddedInput, weight, inputZp, weightZp}, outputTensor, stridesAttr, dilationAttr) .getResult(0); break; - case 3: + case 3: { + // The quantized version uses a different channel ordering so we need to + // permute the tensors in order to use the existing path. We should + // eventually directly support this channel ordering. + llvm::SmallVector inPerms, weightPerms; + inPerms.push_back(0); // N stays at the front for input. + // Then we expect the spatial dimensions + for (size_t i = 0; i < numSpatialDims; ++i) { + inPerms.push_back(i + 2); + weightPerms.push_back(i + 2); + } + inPerms.push_back(1); + weightPerms.append({1, 0}); + + paddedInput = + transposeValue(op.getLoc(), paddedInput, inPerms, rewriter); + weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter); + outputTensor = + transposeValue(op.getLoc(), outputTensor, inPerms, rewriter); + conv = rewriter .create( loc, outputTensor.getType(), ValueRange{paddedInput, weight, inputZp, weightZp}, outputTensor, stridesAttr, dilationAttr) .getResult(0); + + llvm::SmallVector outPerms; + outPerms.push_back(0); + outPerms.push_back(inPerms.size() - 1); + for (size_t i = 0; i < numSpatialDims; ++i) { + outPerms.push_back(i + 1); + } + conv = transposeValue(op.getLoc(), conv, outPerms, rewriter); + break; + } default: return rewriter.notifyMatchFailure( op, "unimplemented: only 1D, 2D, and 3D convolution supported"); }; - llvm::SmallVector outPerms; - outPerms.push_back(0); - outPerms.push_back(inPerms.size() - 1); - for (size_t i = 0; i < numSpatialDims; ++i) { - outPerms.push_back(i + 1); - } - conv = transposeValue(op.getLoc(), conv, outPerms, rewriter); - Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { Type resultElementType = diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index b4f051bc24a9..3ba8cef22cff 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4639,6 +4639,25 @@ class ConvertAtenIndexTensorOpNone } }; +Value wrapNegativeIndices(Value index, int maxIndex, Operation *op, + ConversionPatternRewriter &rewriter) { + + auto zeroValue = tosa::getConstTensor(rewriter, op, 0, {}).value(); + auto maxIndexValue = + tosa::getConstTensor(rewriter, op, maxIndex, {}).value(); + + auto indexType = dyn_cast(index.getType()); + + auto wrappedIndicesOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), indexType, maxIndexValue, index); + auto boolType = indexType.clone(rewriter.getIntegerType(1)); + auto isNegativeIndices = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), boolType, zeroValue, index); + return tosa::CreateOpAndInfer(rewriter, op->getLoc(), + indexType, isNegativeIndices, + wrappedIndicesOp, index); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor, @@ -4677,6 +4696,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outType = getTypeConverter()->convertType(op.getType()); + Operation *indicesTf; + // Support for multiple indexes if (indexTensors.size() > 1) { // t[i, i] @@ -4710,6 +4731,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( index); } + index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op, + rewriter); // Expand last dim of index to tf indices [2,3] -> [2,3,1] SmallVector indiceShapeOneDim; for (auto shape : indexShape) { @@ -4852,49 +4875,39 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto indicesShapeConcat = indexesShape[0]; uint64_t lastDim = indexesRank[0]; indicesShapeConcat.push_back(indicesTfConcatTensors.size()); - auto indicesTf = tosa::CreateOpAndInfer( + indicesTf = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)), indicesTfConcatTensors, lastDim); - if (!indicesTf) { - return rewriter.notifyMatchFailure( - op, "Convert TorchIndex To TfIndices fail."); - } - // do the tf gathernp algorithm with tf style indices as input. - auto result = tosa::convertGatherNdOp(rewriter, op, outType, input, - indicesTf.getResult()); + } else { - if (!result) { - return rewriter.notifyMatchFailure( - op, "Convert GatherNdOp fail for index tensor."); + // Single index + auto index = indexTensors[0]; + auto indexType = dyn_cast(index.getType()); + auto indexShape = indexType.getShape(); + // index i64 to i32 for tosa compatible + if (indexType.getElementType() != rewriter.getIntegerType(32)) { + index = rewriter.create( + op->getLoc(), + RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), + index); } - rewriter.replaceOp(op, {result.value()}); - return success(); - } + index = + wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter); - // Support for multiple index - auto index = indexTensors[0]; - auto indexType = dyn_cast(index.getType()); - auto indexShape = indexType.getShape(); - // index i64 to i32 for tosa compatible - if (indexType.getElementType() != rewriter.getIntegerType(32)) { - index = rewriter.create( - op->getLoc(), - RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index); - } - - // Expand last dim of index to tf indices [2,3] -> [2,3,1] - SmallVector indicesShape; - for (auto shape : indexShape) { - indicesShape.push_back(shape); + // Expand last dim of index to tf indices [2,3] -> [2,3,1] + SmallVector indicesShape; + for (auto shape : indexShape) { + indicesShape.push_back(shape); + } + indicesShape.push_back(1); + indicesTf = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index, + rewriter.getDenseI64ArrayAttr(indicesShape)); } - indicesShape.push_back(1); - auto indicesTf = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), - RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index, - rewriter.getDenseI64ArrayAttr(indicesShape)); if (!indicesTf) { return rewriter.notifyMatchFailure(op, @@ -4902,7 +4915,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } // do the tf gathernp algorithm with tf style indices as input. auto result = tosa::convertGatherNdOp(rewriter, op, outType, input, - indicesTf.getResult()); + indicesTf->getResult(0)); if (!result) { return rewriter.notifyMatchFailure( diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 25c7d240ab01..bb7c0dfe43d8 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -30,6 +30,24 @@ using namespace mlir::torch::Torch; // Utilities //===----------------------------------------------------------------------===// +OpFoldResult genericViewLikeFold(Attribute self, Type resultType) { + auto selfAttr = dyn_cast_or_null(self); + if (!selfAttr) + return nullptr; + + auto resultTy = dyn_cast_or_null(resultType); + if (!resultTy || !resultTy.areAllSizesKnown()) + return nullptr; + + if (selfAttr.isSplat()) { + return SplatElementsAttr::get(resultTy.toBuiltinTensor(), + selfAttr.getSplatValue()); + } + return DenseElementsAttr::get( + resultTy.toBuiltinTensor(), + llvm::to_vector(selfAttr.getValues())); +} + Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder, Location loc, Value value, Type desiredType, @@ -1049,6 +1067,8 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns, //===----------------------------------------------------------------------===// OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { + if (auto genericFold = genericViewLikeFold(adaptor.getSelf(), getType())) + return genericFold; auto inputType = dyn_cast(getOperand(0).getType()); if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1) return nullptr; @@ -2236,10 +2256,22 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenFlattenUsingIntsOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenFlattenUsingIntsOp::fold(FoldAdaptor adaptor) { + return genericViewLikeFold(adaptor.getSelf(), getType()); +} + //===----------------------------------------------------------------------===// // AtenUnflattenIntOp //===----------------------------------------------------------------------===// +OpFoldResult AtenUnflattenIntOp::fold(FoldAdaptor adaptor) { + return genericViewLikeFold(adaptor.getSelf(), getType()); +} + void AtenUnflattenIntOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { // if there are only two sizes and one of them is statically 1, then convert @@ -3737,6 +3769,69 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) { adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; }); } +//===----------------------------------------------------------------------===// +// AtenTransposeIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenTransposeIntOp::fold(FoldAdaptor adaptor) { + // first check for no-op + IntegerAttr dim0 = dyn_cast_or_null(adaptor.getDim0()); + IntegerAttr dim1 = dyn_cast_or_null(adaptor.getDim1()); + if (!dim0 || !dim1) + return nullptr; + int64_t _dim0 = dim0.getValue().getSExtValue(); + int64_t _dim1 = dim1.getValue().getSExtValue(); + auto selfTy = dyn_cast(getSelf().getType()); + if (!selfTy || !selfTy.hasSizes()) + return nullptr; + int64_t rank = selfTy.getSizes().size(); + _dim0 = toPositiveDim(_dim0, rank); + _dim1 = toPositiveDim(_dim1, rank); + if (!isValidDim(_dim0, rank) || !isValidDim(_dim1, rank)) + return nullptr; + // if dims are the same, return self + if (_dim0 == _dim1) + return getSelf(); + + // We set a maximum folding size of 16. This is a reasonable upper limit + // for shape computations. + constexpr int64_t kMaxFoldSize = 16; + auto self = dyn_cast_or_null(adaptor.getSelf()); + if (!self || self.getNumElements() > kMaxFoldSize) + return nullptr; + auto resultTy = dyn_cast(getType()); + if (!selfTy || !resultTy || !selfTy.areAllSizesKnown()) + return nullptr; + if (self.isSplat()) + return SplatElementsAttr::get(resultTy.toBuiltinTensor(), + self.getSplatValue()); + + // TODO: add support for rank != 2 + if (rank != 2) + return nullptr; + + ArrayRef sizes = selfTy.getSizes(); + auto values = llvm::to_vector(self.getValues()); + // reordered[i] = Trans[i//sizes[0], i % sizes[0]] = Self[i % sizes[0], + // i//sizes[0]] = values[(i % sizes[0])*sizes[1] + (i//sizes[0])]. + // e.g., Self size = [4,2]; Trans size = [2,4]. + // reindex(i) = (i % 4)*2 + (i // 4) . + // i = 0 -> Trans[0,0] -> Self[0,0] -> 0 . + // i = 1 -> Trans[0,1] -> Self[1,0] -> 2 . + // i = 2 -> Trans[0,2] -> Self[2,0] -> 4 . + // i = 3 -> Trans[0,3] -> Self[3,0] -> 6 . + // i = 4 -> Trans[1,0] -> Self[0,1] -> 1 . + // i = 5 -> Trans[1,1] -> Self[1,1] -> 3 . + auto reindex = [&](int64_t i) { + return (i % sizes[0]) * sizes[1] + (i / sizes[0]); + }; + SmallVector reordered; + for (int64_t i = 0; i < self.getNumElements(); i++) { + reordered.push_back(values[reindex(i)]); + } + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), reordered); +} + //===----------------------------------------------------------------------===// // AtenCatOp //===----------------------------------------------------------------------===// @@ -3913,15 +4008,17 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { // Fold the slice if the output tensor is relatively small, currently // coded to 16: constexpr int64_t kMaxFold = 16; - if (input && start && step && dim && count <= kMaxFold) { + if (input && start && step && dim && end && count <= kMaxFold) { int64_t begin = start.getValue().getSExtValue(); int64_t limit = end.getValue().getSExtValue(); int64_t stride = step.getValue().getSExtValue(); - if (stride < 1) - return nullptr; begin = begin < 0 ? begin + inType.getSizes()[dimInt] : begin; limit = limit < 0 ? limit + inType.getSizes()[dimInt] : limit; + limit = limit < 0 ? -1 : limit; limit = std::min(limit, inType.getSizes()[dimInt]); + assert((stride > 0 && begin < limit) || + (stride < 0 && begin > limit) && + "aten.slice.Tensor iteration args are statically invalid."); int64_t inputRank = inType.getSizes().size(); llvm::SmallVector inputStrides(inputRank, 1); @@ -3934,10 +4031,21 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { auto recursiveIter = [&](auto &self, int64_t currDim, int64_t currOffset) { if (currDim >= inputRank) return; - size_t _begin = (currDim == dimInt) ? begin : 0; - size_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim]; - size_t _stride = (currDim == dimInt) ? stride : 1; - for (size_t i = _begin; i < _limit; i += _stride) { + int64_t _stride = (currDim == dimInt) ? stride : 1; + int64_t _begin = (currDim == dimInt) ? begin : 0; + int64_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim]; + // ensure that the limit is reached exactly (even with negative strides) + // E.g., with begin = 0, limit = 10, stride = 3, we modify limit to be 11 + // = 10 + (10-0) % 3 . + // E.g., with begin = 8, limit = -1, stride = -2, limit becomes -2 = -1 + + // (-1-8) % (-2) - stride = -1 + 1 - 2 = -2 . + // Note: cpp uses true math remainder "n % d = least positive int, x, such + // that d divides (n - x)" + int64_t limit_rem = (_limit - _begin) % _stride; + limit_rem = + (_stride > 0 || limit_rem == 0) ? limit_rem : limit_rem - _stride; + _limit += limit_rem; + for (int64_t i = _begin; std::abs(_limit - i) > 0; i += _stride) { if (currDim == inputRank - 1) { values.push_back(input.getValues()[currOffset + i]); } @@ -5272,18 +5380,38 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) { } //===----------------------------------------------------------------------===// -// AtenMaxPool2dWithIndicesOp +// AtenMaxPoolWithIndicesOp //===----------------------------------------------------------------------===// -void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns( - RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) { +namespace { + +template struct MaxPoolWithoutIndices { + using type = OpTy; +}; + +template <> struct MaxPoolWithoutIndices { + using type = AtenMaxPool2dOp; +}; + +template <> struct MaxPoolWithoutIndices { + using type = AtenMaxPool3dOp; +}; + +} // namespace + +template +struct SimplifyMaxPoolWithIndices : public mlir::OpRewritePattern { + SimplifyMaxPoolWithIndices(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + LogicalResult + matchAndRewrite(OpTy op, mlir::PatternRewriter &rewriter) const override { if (!op.getResult1().use_empty()) { return rewriter.notifyMatchFailure( - op, "result1 of MaxPool2dWithIndices should be unused"); + op, "result1 of MaxPoolWithIndices should be unused"); } - Value result = rewriter.create( + Value result = rewriter.create::type>( op->getLoc(), op.getResult0().getType(), op.getSelf(), op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(), op.getCeilMode()); @@ -5291,7 +5419,17 @@ void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns( op.getResult0().replaceAllUsesWith(result); rewriter.eraseOp(op); return success(); - }); + } +}; + +void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add>(context); +} + +void AtenMaxPool3dWithIndicesOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add>(context); } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 8ee2dde985ef..327d545eb664 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6683,6 +6683,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7285,6 +7289,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -11056,6 +11064,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %10 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional>) -> !torch.list {\n" +" %0 = call @\"__torch_mlir_shape_fn.aten.upsample_nearest2d.vec\"(%arg0, %arg1, %arg3) : (!torch.list, !torch.optional>, !torch.optional>) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.prims.split_dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -12073,6 +12095,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -12265,6 +12295,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -12541,6 +12612,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index df12dec58b7d..8769cafe8316 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3489,6 +3489,59 @@ class DecomposeAtenLeakyReluBackwardOp }; } // namespace +namespace { +class DecomposeAtenRreluWithNoiseBackwardOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value gradOutput = op.getGradOutput(); + Value self = op.getSelf(); + Value noise = op.getNoise(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + bool training; + if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { + return rewriter.notifyMatchFailure(op, + "training should be a bool constant"); + } + + bool selfIsResult = false; + if (!matchPattern(op.getSelfIsResult(), + m_TorchConstantBool(&selfIsResult)) || + selfIsResult) + return rewriter.notifyMatchFailure( + op, "unimplemented: self_is_result should be false"); + + double lower, upper; + if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) || + !matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) { + return rewriter.notifyMatchFailure( + op, "lower and upper should be float constants"); + } + + if (training && (upper - lower > 0.000001)) { + Value rreluWithNoiseBackwardOutput = + rewriter.create(loc, resType, gradOutput, noise); + rewriter.replaceOp(op, rreluWithNoiseBackwardOutput); + } else { + double negative_slope = (upper + lower) / 2; + Value cstNegativeSlope = rewriter.create( + loc, rewriter.getF64FloatAttr(negative_slope)); + rewriter.replaceOpWithNewOp( + op, resType, gradOutput, self, cstNegativeSlope, + op.getSelfIsResult()); + } + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenPreluOp : public OpRewritePattern { public: @@ -3588,6 +3641,82 @@ class DecomposeAtenRreluOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenRreluWithNoiseOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value noise = op.getNoise(); + Value lower = op.getLower(); + Value upper = op.getUpper(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + bool training; + if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { + return rewriter.notifyMatchFailure(op, "training should be a constant"); + } + + Value constantZeroFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value constantOneFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value constantTwoFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + + Value alpha; + if (training) { + Value none = rewriter.create(loc); + Value emptyTensor = rewriter.create( + loc, resType, self, constantZeroFloat, /*dtype=*/none, + /*layout=*/none, + /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); + alpha = rewriter.create(loc, resType, emptyTensor, + /*from=*/lower, /*to=*/upper, + /*generator=*/none); + } else { + Value half = rewriter.create(loc, constantTwoFloat.getType(), + lower, upper); + alpha = rewriter.create(loc, constantTwoFloat.getType(), half, + constantTwoFloat); + } + + Value zeroTensor = + createRank0Tensor(rewriter, loc, resType, constantZeroFloat); + Value positiveOutput = + rewriter.create(loc, resType, zeroTensor, self); + + Value scaledSelf; + if (training) { + scaledSelf = rewriter.create(loc, resType, self, alpha); + auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(), + rewriter.getI1Type()); + Value oneTensor = + createRank0Tensor(rewriter, loc, resType, constantOneFloat); + Value not_positive = rewriter.create( + loc, boolResType, self, constantZeroFloat); + noise = rewriter.create(loc, resType, not_positive, + alpha, oneTensor); + } else { + scaledSelf = rewriter.create(loc, resType, self, alpha); + } + + Value negativeOutput = + rewriter.create(loc, resType, zeroTensor, scaledSelf); + Value rreluOutput = rewriter.create( + loc, resType, positiveOutput, negativeOutput, constantOneFloat); + rewriter.replaceOp(op, rreluOutput); + return success(); + } +}; +} // namespace + // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1)) namespace { class DecomposeAtenCeluOp : public OpRewritePattern { @@ -10013,6 +10142,9 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 9dfd905a1d0a..057182b178c8 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -498,6 +498,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp index f1ebeb307976..d599fd5369f4 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp @@ -32,9 +32,6 @@ class FoldPrimUncheckedCastOp : public OpRewritePattern { } // namespace namespace { -// TODO: Only unroll inside the shape calculation region. -// Maybe do this by only applying patterns and folding greedily on the ops -// inside the region + the shape.calculate op itself? class FullyUnrollPrimLoopOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -42,6 +39,12 @@ class FullyUnrollPrimLoopOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op->getLoc(); MLIRContext *context = op->getContext(); + // Only unroll loops if they are contained in a shape calculate region. + Region *region = op->getParentRegion(); + Operation *parentOp = region->getParentOp(); + if (!parentOp || !isa(parentOp)) + return rewriter.notifyMatchFailure( + op, "Loop is not contained in a shape calculation region."); if (!op.isForLike()) return rewriter.notifyMatchFailure(op, "Loop is not for-like"); int64_t maxTripCount; diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index 0f2533e063f0..53de48f21934 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -57,16 +57,16 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::BoolType type) -> std::optional { return IntegerType::get(type.getContext(), 1); }); - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 1 && type.isSignless())) - return std::nullopt; - assert(inputs.size() == 1); - assert(isa(inputs[0].getType())); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + IntegerType type, ValueRange inputs, + Location loc) -> Value { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 1 && type.isSignless())) + return Value(); + assert(inputs.size() == 1); + assert(isa(inputs[0].getType())); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -83,19 +83,19 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::IntType type) -> std::optional { return IntegerType::get(type.getContext(), 64); }); - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 64 && type.isSignless())) - return std::nullopt; - // Other input type to be converted to i64 are handled by other - // materializers. - if (!isa(inputs[0].getType())) - return std::nullopt; - assert(inputs.size() == 1); - return builder.createOrFold(loc, inputs[0]); - }); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + IntegerType type, ValueRange inputs, + Location loc) -> Value { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 64 && type.isSignless())) + return Value(); + // Other input type to be converted to i64 are handled by other + // materializers. + if (!isa(inputs[0].getType())) + return Value(); + assert(inputs.size() == 1); + return builder.createOrFold(loc, inputs[0]); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -112,13 +112,13 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::FloatType type) -> std::optional { return Float64Type::get(type.getContext()); }); - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, Float64Type type, ValueRange inputs, - Location loc) -> std::optional { - assert(inputs.size() == 1); - assert(isa(inputs[0].getType())); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + Float64Type type, ValueRange inputs, + Location loc) -> Value { + assert(inputs.size() == 1); + assert(isa(inputs[0].getType())); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -137,19 +137,19 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target, [](Torch::GeneratorType type) -> std::optional { return IntegerType::get(type.getContext(), 64); }); - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 64 && type.isSignless())) - return std::nullopt; - // Other input type to be converted to i64 are handled by other - // materializers. - if (!isa(inputs[0].getType())) - return std::nullopt; - assert(inputs.size() == 1); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + IntegerType type, ValueRange inputs, + Location loc) -> Value { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 64 && type.isSignless())) + return Value(); + // Other input type to be converted to i64 are handled by other + // materializers. + if (!isa(inputs[0].getType())) + return Value(); + assert(inputs.size() == 1); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 9adb31a22365..e9784a52fa85 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -394,15 +394,6 @@ "AtenIntBoolOpModule_basic", "AtenIntMM_basic", "AtenItemFpOpModule_basic", - "AtenMatmulQMixedSigni8Transpose_basic", - "AtenMatmulQMixedSigni8_basic", - "AtenMatmulQint8MV_basic", - "AtenMatmulQint8_basic", - "AtenMatmulQint8VM_basic", - "AtenMatmulQint8VV_basic", - "AtenMmQMixedSigni8_basic", - "AtenMmQint8_basic", - "AtenMmQuint8_basic", "QuantizedReluInt32_basic", "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", @@ -531,6 +522,8 @@ "ChunkListUnpackUnevenDynamic_Module_basic", "ChunkListUnpackUneven_Module_basic", "ChunkListUnpack_Module_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "SplitTensorGetItem_Module_basic", "SplitTensorLastSmallerModule_basic", "SplitTensorListUnpackModule_basic", @@ -564,10 +557,6 @@ FX_IMPORTER_XFAIL_SET |= { "AtenSubFloatModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - "ElementwiseRreluEvalModule_basic", - "ElementwiseRreluEvalStaticModule_basic", - "ElementwiseRreluTrainModule_basic", - "ElementwiseRreluTrainStaticModule_basic", "EqIntModule_basic", "GeFloatModule_basic", "GtIntModule_basic", @@ -591,6 +580,9 @@ "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", # Randomly mismatching values "ConvolutionModule2DTranspose_basic", + # torch export: RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } FX_IMPORTER_STABLEHLO_XFAIL_SET = { @@ -747,7 +739,6 @@ "DiagonalModule_with_offset", "DivFloatModule_basic", "DivIntModule_basic", - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseErfIntModule_basic", @@ -849,8 +840,6 @@ "NormScalarComplexModule_basic", "NormScalarModule_basic", "NormalFunctionalModule_basic", - "NumToTensorFloatModule_basic", - "NumToTensorIntModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", "PowIntFloatModule_basic", @@ -886,7 +875,6 @@ "ReplicationPad2dModule_left0", "ReplicationPad2dModule_right0", "ReplicationPad2dModule_top0", - "RsubInt0d_NumToTensor_Module_basic", "ScalarImplicitFloatModule_basic", # REMOVE WHEN ENABLE_GQA IS ADDED "ScatterReduceFloatMaxModule", @@ -1021,6 +1009,11 @@ "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticSize_basic", "UpSampleNearest2d_basic", + # RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } FX_IMPORTER_STABLEHLO_CRASHING_SET = { @@ -1039,6 +1032,9 @@ # materialization callback produced value of incorrect type failed "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", + # torch export: RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } STABLEHLO_PASS_SET = { @@ -1267,6 +1263,10 @@ "ElementwisePreluStaticModule_basic", "ElementwiseReciprocalModule_basic", "ElementwiseReluModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", "ElementwiseRemainderTensorModule_Float_basic", "ElementwiseRemainderTensorModule_Float_NegativeDividend_basic", "ElementwiseRemainderTensorModule_Int_Float_basic", @@ -1754,7 +1754,6 @@ "ArangeStartOutModule_basic", "ScatterSrcStaticModule_basic", # Runtime op verification: Out of bounds access - "IndexTensorNegativeIndexModule_basic", "ReduceAllDimEmpty_basic", } @@ -1762,7 +1761,6 @@ "ScatterSrcModule_basic", "ScatterSrcStaticModule_basic", "HBC_basic", - "IndexTensorNegativeIndexModule_basic", "InterpolateDynamicModule_scales_recompute_bilinear", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", @@ -2172,6 +2170,7 @@ "ElementwiseReciprocalModule_basic", "ElementwiseRelu6Module_basic", "ElementwiseReluModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic", "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic", "ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic", @@ -2223,6 +2222,7 @@ "HardswishRandomModule_basic", "HardtanhBackward_basic", "IndexTensorMultiIndexStaticModule_basic", + "IndexTensorNegativeIndexModule_basic", "IndexTensorStaticModule_basic", "IscloseStaticModuleTrue_basic", "IscloseStaticModule_basic", @@ -2304,6 +2304,10 @@ "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", + "RreluWithNoiseBackwardEvalModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", "RepeatModule_basic", "RepeatInterleaveSelfIntNoDimModule_basic", "ResNet18StaticModule_basic", @@ -2502,6 +2506,10 @@ "ViewSizeFromOtherTensor_basic", "RenormModuleFloat32NegativeDim_basic", "RenormModuleFloat32_basic", + "RreluWithNoiseBackwardEvalModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", } ) - { ### Test failing in make_fx_tosa but not in tosa @@ -2734,20 +2742,6 @@ "MultinomialModule2D_basic", "MultinomialModule2D_F32", "PixelShuffleModuleStaticRank4Float32_basic", - "ReflectionPad1dModule2dInput_Right", - "ReflectionPad1dModule2dInput_basic", - "ReflectionPad1dModule3dInput_Left", - "ReflectionPad1dModule3dInput_basic", - "ReflectionPad2dModule_Bottom", - "ReflectionPad2dModule_Left", - "ReflectionPad2dModule_Right", - "ReflectionPad2dModule_Top", - "ReflectionPad2dModule_basic", - "ReplicationPad2dModule_basic", - "ReplicationPad2dModule_bottom0", - "ReplicationPad2dModule_left0", - "ReplicationPad2dModule_right0", - "ReplicationPad2dModule_top0", "SliceCopyEndGreaterThanDimSize_Module_basic", "SliceCopyNegative_Module_basic", "SliceCopyNonZeroDim_Module_basic", @@ -2934,6 +2928,10 @@ "ElementwiseRemainderTensorModule_Int_basic", "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic", + "ElementwiseRreluWithNoiseEvalModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "ElementwiseSgnModule_basic", "EmptyStridedModule_basic", "EmptyStridedSizeIntStrideModule_basic", @@ -3081,6 +3079,11 @@ "ReduceL1NormComplexModule_basic", "ReduceL2NormComplexModule_basic", "ReduceL3NormKeepDimComplexModule_basic", + "RreluWithNoiseBackwardEvalModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", + "RreluWithNoiseForwardBackwardModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", "ReshapeExpandModule_basic", @@ -3618,6 +3621,8 @@ "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "ElementwiseRsqrtIntModule_basic", "ElementwiseSigmoidIntModule_basic", "ElementwiseSinIntModule_basic", @@ -3673,7 +3678,6 @@ "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", - "IndexTensorNegativeIndexModule_basic", "InterpolateDynamicModule_sizes_bilinear", "InterpolateDynamicModule_sizes_nearest", "InterpolateStaticModule_scales_bilinear_align_corners", @@ -3993,6 +3997,12 @@ "EinsumStaticModule_basic", "EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic", "EinsumStaticWithEllipsisSlicingModule_basic", + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "GridSamplerBasic1_basic", "GridSamplerBasic2_basic", "GridSamplerBasic3_basic", @@ -4029,8 +4039,10 @@ FX_IMPORTER_TOSA_XFAIL_SET |= { "ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseLogSigmoidModule_basic", - "ElementwiseRreluEvalModule_basic", - "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "NumToTensorFloatModule_basic", "NumToTensorIntModule_basic", "RsubInt0d_NumToTensor_Module_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index a7b89449d7e9..64584af0371c 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -298,6 +298,9 @@ def aten〇gelu_backward〡shape(grad_output: List[int], self: List[int], approx def aten〇leaky_relu_backward〡shape(grad_output: List[int], self: List[int], negative_slope: float, self_is_result: bool) -> List[int]: return upstream_shape_functions.unary(grad_output) +def aten〇rrelu_with_noise_backward〡shape(grad_output: List[int], self: List[int], noise: List[int], lower: float, upper: float, training: bool, self_is_result: bool) -> List[int]: + return upstream_shape_functions.unary(grad_output) + def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], min_val: float, max_val: float) -> List[int]: return upstream_shape_functions.unary(grad_output) @@ -634,6 +637,9 @@ def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]: def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇rrelu_with_noise〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇selu〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -2343,6 +2349,20 @@ def aten〇upsample_nearest2d〇vec〡shape(input: List[int], output_size: Optio assert scale_factors is not None return [input[0], input[1], int(input[2] * scale_factors[0]), int(input[3] * scale_factors[1])] +@check_shape_function([ + Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True) +]) +def aten〇upsample_bilinear2d〡shape(self: List[int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]: + return [self[0], self[1], output_size[0], output_size[1]] + +@check_shape_function([ + Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True, None), + Invocation(TensorOfShape(1, 3, 10, 9), None, True, [2.0, 2.3]), + Invocation(TensorOfShape(1, 3, 5, 6), None, True, [2.5, 1.0]) +]) +def aten〇upsample_bilinear2d〇vec〡shape(input: List[int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> List[int]: + return aten〇upsample_nearest2d〇vec〡shape(input, output_size, scale_factors) + # ============================================================================== # Dtype Functions # ============================================================================== @@ -3145,6 +3165,15 @@ def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], promoted_dtype = promote_dtypes(ranks, dtypes) return promoted_dtype +@check_dtype_function([Invocation(TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), 0.1, 0.9, False, False) for dtype in _SORTED_TORCH_TYPES]) +def aten〇rrelu_with_noise_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex], upper: Union[int, float, complex], training: bool, self_is_result: bool) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -3312,6 +3341,15 @@ def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, flo assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, *all_integer_dtypes()})) +def aten〇rrelu_with_noise〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + noise_rank, noise_dtype = noise_rank_dtype + assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) + assert is_float_dtype(noise_dtype) or is_complex_dtype(noise_dtype) + assert self_rank == noise_rank + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -3569,6 +3607,16 @@ def aten〇upsample_nearest2d〇vec〡dtype(input_rank_dtype: Tuple[int, int], o self_rank, self_dtype = input_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True)) +def aten〇upsample_bilinear2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True, scale_factors=None)) +def aten〇upsample_bilinear2d〇vec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> int: + self_rank, self_dtype = input_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) def aten〇view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index a1e201720c1c..4fa0cc0e6f6a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -302,6 +302,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::relu6 : (Tensor) -> (Tensor)", "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)", "aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)", + "aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)", "aten::celu : (Tensor, Scalar) -> (Tensor)", "aten::selu : (Tensor) -> (Tensor)", "aten::sigmoid : (Tensor) -> (Tensor)", @@ -636,7 +637,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)") emit( - "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" + "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)", + has_canonicalizer=True, ) emit( "aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" @@ -683,7 +685,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)") emit("aten::adaptive_max_pool3d : (Tensor, int[]) -> (Tensor, Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") - emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") + emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)", has_folder=True) emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)") emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True) emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)") @@ -768,9 +770,11 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)") emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) - emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)") + emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)", has_folder=True) emit( - "aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)", has_canonicalizer=True + "aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)", + has_canonicalizer=True, + has_folder=True, ) emit("aten::dim : (Tensor) -> (int)", has_folder=True) emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True) @@ -1018,6 +1022,10 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) -> (Tensor)") emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)") emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)") + emit( + "aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)" + ) + emit("aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)") emit( "aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)" ) @@ -1177,6 +1185,9 @@ def emit_with_mutating_variants(key, **kwargs): "aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)" ) emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)") + emit( + "aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)" + ) # quantized ops emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py index e209d15b2b0b..5e6e093902c4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -322,3 +322,164 @@ def forward(self, grad, input): @register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule()) def LeakyReluBackwardStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class RreluWithNoiseBackwardTrainModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + return torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.1, + upper=0.9, + training=True, + self_is_result=False, + ) + + +@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainModule()) +def RreluWithNoiseBackwardTrainModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +class RreluWithNoiseBackwardTrainStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4, 5], torch.float32, True), + ([3, 4, 5], torch.float32, True), + ([3, 4, 5], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + return torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.1, + upper=0.9, + training=True, + self_is_result=False, + ) + + +@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainStaticModule()) +def RreluWithNoiseBackwardTrainStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class RreluWithNoiseBackwardEvalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + return torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.1, + upper=0.9, + training=False, + self_is_result=False, + ) + + +@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalModule()) +def RreluWithNoiseBackwardEvalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +class RreluWithNoiseBackwardEvalStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4, 5], torch.float32, True), + ([3, 4, 5], torch.float32, True), + ([3, 4, 5], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + return torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.1, + upper=0.9, + training=False, + self_is_result=False, + ) + + +@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalStaticModule()) +def RreluWithNoiseBackwardEvalStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +class RreluWithNoiseForwardBackwardModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + res = torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.4, + upper=0.6, + training=True, + self_is_result=False, + ) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule()) +def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils): + grad = tu.rand(256, 244) + input = tu.rand(256, 244, low=-1.0, high=1.0) + noise = tu.rand(256, 244) + torch.ops.aten.rrelu_with_noise(input, noise, lower=0.4, upper=0.6, training=True) + module.forward(grad, input, noise) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index f21e0eaf02ee..389de50e4cc1 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1179,6 +1179,88 @@ def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseRreluWithNoiseTrainModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] + ) + def forward(self, x, noise): + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule()) +def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128)) + + +# ============================================================================== + + +class ElementwiseRreluWithNoiseTrainStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)] + ) + def forward(self, x, noise): + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule()) +def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128)) + + +# ============================================================================== + + +class ElementwiseRreluWithNoiseEvalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] + ) + def forward(self, x, noise): + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalModule()) +def ElementwiseRreluWithNoiseEvalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3)) + + +# ============================================================================== + + +class ElementwiseRreluWithNoiseEvalStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([5, 3], torch.float32, True), ([5, 3], torch.float32, True)]) + def forward(self, x, noise): + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalStaticModule()) +def ElementwiseRreluWithNoiseEvalStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3)) + + +# ============================================================================== + + class ElementwiseCeluStaticModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py index fd241ba00850..2067e13d5997 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -361,6 +361,8 @@ def AtenMmIntTypes_basic(module, tu: TestUtils): # ============================================================================== +# For DQ-Q fake quantization ops +import torch.ao.quantization.fx._decomposed class AtenMmQint8(torch.nn.Module): @@ -376,12 +378,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.mm(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.mm(x, y) + return z @register_test_case(module_factory=lambda: AtenMmQint8()) @@ -408,12 +412,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0215, 160) - qy = torch.dequantize(qy) - qz = torch.mm(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.199, 65, 0, 255, torch.uint8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0215, 160, 0, 255, torch.uint8 + ) + z = torch.mm(x, y) + return z @register_test_case(module_factory=lambda: AtenMmQuint8()) @@ -440,12 +446,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) - qy = torch.dequantize(qy) - qz = torch.mm(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.03, -66, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.025, 160, 0, 255, torch.uint8 + ) + z = torch.mm(x, y) + return z @register_test_case(module_factory=lambda: AtenMmQMixedSigni8()) @@ -499,12 +507,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQint8VM()) @@ -529,12 +539,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQint8VV()) @@ -559,12 +571,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQint8MV()) @@ -589,12 +603,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQint8()) @@ -621,12 +637,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.03, -66, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.025, 160, 0, 255, torch.uint8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8()) @@ -653,13 +671,15 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) - qy = torch.dequantize(qy) - qy = torch.transpose(qy, 1, 2) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.03, -66, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.025, 160, 0, 255, torch.uint8 + ) + y = torch.transpose(y, 1, 2) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose()) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index f9e0abfabac1..dd4f3a19ad33 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -160d421a40e934ac8183e47f9cbc8618a4bd97dd +c787213d413e85c66bdad0d8c9cde1c5ced34b1b diff --git a/test/CAPI/torch.c b/test/CAPI/torch.c index d42cf96d554c..3d1308f08b25 100644 --- a/test/CAPI/torch.c +++ b/test/CAPI/torch.c @@ -33,12 +33,12 @@ static void testTensor(MlirContext ctx, intptr_t numSizes, int64_t *sizes, bool TTT##hasDtype = torchMlirTorch##TTT##TypeHasDtype(TTT##Type); \ fprintf(stderr, #TTT "Type %s hasDtype: %d\n", testName, TTT##hasDtype); \ if (TTT##hasSizes) { \ - fprintf(stderr, #TTT "Type %s rank: %zu\n", testName, \ + fprintf(stderr, #TTT "Type %s rank: %" PRId64 "\n", testName, \ torchMlirTorch##TTT##TypeGetRank(TTT##Type)); \ int64_t *TTT##Sizes = malloc(sizeof(int64_t) * numSizes); \ torchMlirTorch##TTT##TypeGetSizes(TTT##Type, TTT##Sizes); \ for (int i = 0; i < numSizes; ++i) { \ - fprintf(stderr, #TTT "Type %s pos %d size: %ld\n", testName, i, \ + fprintf(stderr, #TTT "Type %s pos %d size: %" PRId64 "\n", testName, i, \ TTT##Sizes[i]); \ } \ } \ diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index e478c53a470e..f51b0f2ee50f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -730,6 +730,86 @@ func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch return %0 : !torch.vtensor<[1,64,56,56],f32> } +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_same_lower +func.func @test_maxpool_2d_same_lower(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_0:.*]] = torch.constant.int 1 + // CHECK: %[[int0_1:.*]] = torch.constant.int 0 + // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int1]], %[[int0]], %[[int1_0]], %[[int0_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308 + // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,3,33,33],f32> + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int2_2:.*]] = torch.constant.int 2 + // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int0_3:.*]] = torch.constant.int 0 + // CHECK: %[[int0_4:.*]] = torch.constant.int 0 + // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_5:.*]] = torch.constant.int 1 + // CHECK: %[[int1_6:.*]] = torch.constant.int 1 + // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_7:.*]] = torch.constant.int 1 + // CHECK: %[[int1_8:.*]] = torch.constant.int 1 + // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,32,32],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_LOWER", torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> + return %0 : !torch.vtensor<[1,3,32,32],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_same_upper +func.func @test_maxpool_2d_same_upper(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int0_0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_1:.*]] = torch.constant.int 1 + // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int0]], %[[int1]], %[[int0_0]], %[[int1_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308 + // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,3,33,33],f32> + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int2_2:.*]] = torch.constant.int 2 + // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int0_3:.*]] = torch.constant.int 0 + // CHECK: %[[int0_4:.*]] = torch.constant.int 0 + // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_5:.*]] = torch.constant.int 1 + // CHECK: %[[int1_6:.*]] = torch.constant.int 1 + // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_7:.*]] = torch.constant.int 1 + // CHECK: %[[int1_8:.*]] = torch.constant.int 1 + // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,32,32],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> + return %0 : !torch.vtensor<[1,3,32,32],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_precomputed_same_upper +func.func @test_maxpool_2d_precomputed_same_upper(%arg0: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64}{ + // CHECK: %[[int3:.*]] = torch.constant.int 3 + // CHECK: %[[int3_0:.*]] = torch.constant.int 3 + // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int3]], %[[int3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int1_1:.*]] = torch.constant.int 1 + // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int1]], %[[int1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int2_2:.*]] = torch.constant.int 2 + // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_3:.*]] = torch.constant.int 1 + // CHECK: %[[int1_4:.*]] = torch.constant.int 1 + // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_3]], %[[int1_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[FUNC4:.*]] = torch.aten.max_pool2d %arg0, %[[list0]], %[[list2]], %[[list1]], %[[list3]], %[[FALSE]] : !torch.vtensor<[1,1,5,5],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,1,3,3],f32> +%0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,3,3],f32> +return %0 : !torch.vtensor<[1,1,3,3],f32> +} + // ----- diff --git a/test/Conversion/TorchToLinalg/convolution.mlir b/test/Conversion/TorchToLinalg/convolution.mlir index 3023c0ba6d8a..480b1eeb9ed2 100644 --- a/test/Conversion/TorchToLinalg/convolution.mlir +++ b/test/Conversion/TorchToLinalg/convolution.mlir @@ -24,12 +24,8 @@ func.func @torch.aten.convolution$nobias(%arg0: !torch.vtensor<[1,24,16,128,128] // CHECK: %[[c7:.*]] = arith.constant 7 : i32 // CHECK: %[[input:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?,?],si8> -> tensor // CHECK: %[[weight:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?,?,?],si8> -> tensor -// CHECK: %[[TransInput:.*]] = linalg.transpose ins(%[[input]] : tensor) -// CHECK-SAME: permutation = [0, 2, 3, 1] -// CHECK: %[[TransWeight:.*]] = linalg.transpose ins(%[[weight]] : tensor) -// CHECK-SAME: permutation = [2, 3, 1, 0] -// CHECK: %[[conv:.*]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} -// CHECK-SAME: ins(%[[TransInput]], %[[TransWeight]], %[[c7]], %[[c3]] : tensor, tensor, i32, i32) +// CHECK: %[[conv:.*]] = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} +// CHECK-SAME: ins(%[[input]], %[[weight]], %[[c7]], %[[c3]] : tensor, tensor, i32, i32) // CHECK-SAME: outs(%[[convout:.*]] : tensor) -> tensor func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtensor<[?,?,?,?],si8>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %false = torch.constant.bool false diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index ae58dc6215c0..0e70994c3663 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2373,3 +2373,35 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t %0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32> return %0 : !torch.vtensor<[2,3,4,4],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,2],si64>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { +// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64> +// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor<[],si64>) -> !torch.list +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[],si64> -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_4]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1xi32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_11]], %[[VAL_12]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_16:.*]] = tosa.gather %[[VAL_10]], %[[VAL_15]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<1x1x8xi64>) -> tensor<4x2xi64> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64> +// CHECK: return %[[RESULT]] : !torch.vtensor<[4,2],si64> + +func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { + %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list + %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + return %1 : !torch.vtensor<[4,2],si64> + } diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index f13bf60cb15b..90b4e103c4fb 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1682,6 +1682,82 @@ func.func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[? return %1 : !torch.tensor<[?],f32> } +// CHECK-LABEL: func.func @torch.aten.view$fold_splat( +// CHECK: %[[SPLAT:.*]] = torch.vtensor.literal(dense<2> : tensor<2x4x1xsi64>) : !torch.vtensor<[2,4,1],si64> +// CHECK: return %[[SPLAT]] : !torch.vtensor<[2,4,1],si64> +func.func @torch.aten.view$fold_splat() -> !torch.vtensor<[2,4,1],si64> { + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense<2> : tensor<8xsi64>) : !torch.vtensor<[8],si64> + %1 = torch.prim.ListConstruct %int2, %int4, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[2,4,1],si64> + return %2 : !torch.vtensor<[2,4,1],si64> +} + +// CHECK-LABEL: func.func @torch.aten.view$fold_literal( +// CHECK: %[[LITERAL:.*]] = torch.vtensor.literal(dense<[ +// CHECK-SAME: [ +// CHECK-SAME: [0, 1], [2, 3], [4, 5], [6, 7]]]> : tensor<1x4x2xsi64>) : !torch.vtensor<[1,4,2],si64> +// CHECK: return %[[LITERAL]] : !torch.vtensor<[1,4,2],si64> +func.func @torch.aten.view$fold_literal() -> !torch.vtensor<[1,4,2],si64> { + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense<[0,1,2,3,4,5,6,7]> : tensor<8xsi64>) : !torch.vtensor<[8],si64> + %1 = torch.prim.ListConstruct %int1, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,4,2],si64> + return %2 : !torch.vtensor<[1,4,2],si64> +} + +// CHECK-LABEL: func.func @torch.aten.transpose.int$fold_literal( +// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[ +// CHECK-SAME: [0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xsi64>) : !torch.vtensor<[2,4],si64> +// CHECK: return %[[LIT]] : !torch.vtensor<[2,4],si64> +func.func @torch.aten.transpose.int$fold_literal() -> !torch.vtensor<[2,4],si64> { + %int-1 = torch.constant.int -1 + %int0 = torch.constant.int 0 + %0 = torch.vtensor.literal(dense<[[0,1],[2,3],[4,5],[6,7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> + %1 = torch.aten.transpose.int %0, %int-1, %int0 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4], si64> + return %1 : !torch.vtensor<[2,4],si64> +} + +// CHECK-LABEL: func.func @torch.aten.transpose.int$fold_noop( +// CHECK: return %arg0 : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten.transpose.int$fold_noop(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %int-1 = torch.constant.int -1 + %int3 = torch.constant.int 3 + %0 = torch.aten.transpose.int %arg0, %int-1, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.slice.Tensor$flip_slice_fold( +// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[ +// CHECK-SAME: [6, 7], [4, 5], [2, 3], [0, 1]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> +// CHECK: return %[[LIT]] : !torch.vtensor<[4,2],si64> +func.func @torch.aten.slice.Tensor$flip_slice_fold() -> !torch.vtensor<[4,2],si64> { + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int0 = torch.constant.int 0 + %0 = torch.vtensor.literal(dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> + %1 = torch.aten.slice.Tensor %0, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + return %1 : !torch.vtensor<[4,2],si64> +} + +// CHECK-LABEL: func.func @torch.aten.slice.Tensor$negative_two_stride_fold( +// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[ +// CHECK-SAME: [6, 7], [2, 3]]> : tensor<2x2xsi64>) : !torch.vtensor<[2,2],si64> +// CHECK: return %[[LIT]] : !torch.vtensor<[2,2],si64> +func.func @torch.aten.slice.Tensor$negative_two_stride_fold() -> !torch.vtensor<[2,2],si64> { + %int-5 = torch.constant.int -5 + %int-1 = torch.constant.int -1 + %int-2 = torch.constant.int -2 + %int0 = torch.constant.int 0 + %0 = torch.vtensor.literal(dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> + %1 = torch.aten.slice.Tensor %0, %int0, %int-1, %int-5, %int-2 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2],si64> + return %1 : !torch.vtensor<[2,2],si64> +} + // CHECK-LABEL: func.func @torch.aten.div.float$fold_zero_dividend( // CHECK: %[[CST0:.*]] = torch.constant.float 0.000000e+00 // CHECK: return %[[CST0]] : !torch.float @@ -3136,6 +3212,24 @@ func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor // ----- +// CHECK-LABEL: @torch.aten.max_pool3d_with_indices$canonicalize( +// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112,112],f32>) -> !torch.vtensor<[10,64,56,56,56],f32> { +// CHECK: %[[RET:.*]] = torch.aten.max_pool3d %[[ARG]] +// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56,56],f32> +func.func @torch.aten.max_pool3d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112,112],f32>) -> !torch.vtensor<[10,64,56,56,56],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %result0, %result1 = torch.aten.max_pool3d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[10,64,56,56,56],f32>, !torch.vtensor<[10,64,56,56,56],si64> + return %result0 : !torch.vtensor<[10,64,56,56,56],f32> +} + +// ----- + // CHECK-LABEL: @torch.aten.clone$no_fold( func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!torch.tensor) { // CHECK: %{{.*}} = torch.aten.clone %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor diff --git a/test/Dialect/Torch/simplify-shape-calculations.mlir b/test/Dialect/Torch/simplify-shape-calculations.mlir index 59884616f13f..af96e108efbd 100644 --- a/test/Dialect/Torch/simplify-shape-calculations.mlir +++ b/test/Dialect/Torch/simplify-shape-calculations.mlir @@ -152,6 +152,23 @@ func.func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch return %0 : !torch.vtensor } +// CHECK-LABEL: func.func @fully_unroll_prim_loop$outside_region( +// CHECK: %[[LOOP:.*]] = torch.prim.Loop +func.func @fully_unroll_prim_loop$outside_region(%arg0: !torch.vtensor, %arg1: !torch.list, %arg2: !torch.int) -> !torch.vtensor { + %true = torch.constant.bool true + %0 = torch.prim.Loop %arg2, %true, init(%arg0) { + ^bb0(%arg3: !torch.int, %arg4: !torch.vtensor): + %1 = torch.shape.calculate { + torch.shape.calculate.yield %arg4 : !torch.vtensor + } shapes { + torch.prim.Print(%arg3) : !torch.int + torch.shape.calculate.yield.shapes %arg1 : !torch.list + } : !torch.vtensor + torch.prim.Loop.condition %true, iter(%1 : !torch.vtensor) + } : (!torch.int, !torch.bool, !torch.vtensor) -> !torch.vtensor + return %0 : !torch.vtensor +} + // CHECK-LABEL: func.func @abstractly_interpret_list_ops$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG1:.*]]: !torch.int,