diff --git a/docs/development.md b/docs/development.md
index f1e72966f84d..bdb4854e4d66 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -187,7 +187,7 @@ sudo apt install clang ccache lld
 - **...run Python regression tests**, run:
 
   ```shell
-  cmake --build build --target check-torch-mlir-python
+  cmake --build build --target check-torch_mlir-python
   ```
 
 TIP: add multiple `--target` options to stack build phases
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index eb08786f7982..d91926276f00 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -12,6 +12,7 @@
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
 
 #include "PopulatePatterns.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
@@ -671,7 +672,8 @@ class ConvertAtenUnflattenIntOp
       return rewriter.notifyMatchFailure(op,
                                          "Expected input type having sizes");
     }
-    int inputRank = inputTensorType.getSizes().size();
+    auto inputTensorSizes = inputTensorType.getSizes();
+    int inputRank = inputTensorSizes.size();
     auto outputSizes = outputTensorType.getSizes();
     int outputRank = outputSizes.size();
 
@@ -692,7 +694,6 @@ class ConvertAtenUnflattenIntOp
       if (outputSizes[i] == Torch::kUnknownSize)
         numDynamicReassocDims++;
     }
-
    SmallVector<Value> reassocSizes;
    if (!getListConstructElements(op.getSizes(), reassocSizes) &&
        numDynamicReassocDims > 1)
@@ -700,7 +701,8 @@ class ConvertAtenUnflattenIntOp
          op, "Must be able to either infer expansion dims, or retrieve them "
              "from list construct");
 
-    auto expandTy = getTypeConverter()->convertType(outputTensorType);
+    RankedTensorType expandTy = cast<RankedTensorType>(
+        getTypeConverter()->convertType(outputTensorType));
    Value expand;
    // When there are less than two dynamic reassociation dims, this will lower
    // to tensor.expand_shape. Otherwise, this lowers to tensor.reshape.
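The next hunk supplies the `output_shape` operands that `tensor.expand_shape` needs for dynamic dimensions: it folds the product of the statically known reassociation sizes and floor-divides the unflattened dimension by that product to resolve an inferred `-1` entry. As a quick sanity check of that arithmetic, here is a minimal eager-mode PyTorch sketch (illustrative values, not taken from the PR):

```python
import torch

# unflatten with one inferred (-1) size: the missing factor is
# input_dim // product(known sizes), i.e. the s0 floordiv s1 affine
# map composed in the lowering below (20 // 4 == 5 here).
x = torch.rand(3, 20)
y = torch.ops.aten.unflatten(x, 1, [4, -1])
assert y.shape == (3, 4, 5)
```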
@@ -717,10 +719,80 @@ class ConvertAtenUnflattenIntOp
       for (int i = dimInt + numSizes; i < outputRank; ++i)
         reassociations[i - numSizes + 1].push_back(i);
     }
-      expand = rewriter
-                   .create<tensor::ExpandShapeOp>(
-                       loc, expandTy, adaptor.getSelf(), reassociations)
+
+      // Convert a torch int size value to an OpFoldResult: a static attribute
+      // if constant, otherwise an index-typed SSA value.
+      auto sizeToOFR = [&](Value sizeVal) -> OpFoldResult {
+        int64_t constantSize;
+        if (matchPattern(sizeVal, m_TorchConstantInt(&constantSize))) {
+          return rewriter.getIndexAttr(constantSize);
+        }
+        SmallVector<Value> singleSizeVec = {sizeVal};
+        Value converted = castIntToIndex(
+            rewriter, loc,
+            getTypeConvertedValues(rewriter, loc, getTypeConverter(),
+                                   singleSizeVec)[0]);
+        return OpFoldResult(converted);
+      };
+
+      int64_t minusOneIdx = -1;
+      OpFoldResult knownProduct = rewriter.getIndexAttr(1);
+      AffineExpr s0 = getAffineSymbolExpr(0, rewriter.getContext());
+      AffineExpr s1 = getAffineSymbolExpr(1, rewriter.getContext());
+      auto mulMap = AffineMap::get(0, 2, s0 * s1, rewriter.getContext());
+
+      // Fold the product of all explicitly given sizes; remember which entry,
+      // if any, is the inferred (-1) one.
+      for (int64_t j = 0, e = reassocSizes.size(); j < e; ++j) {
+        int64_t constantSize;
+        if (matchPattern(reassocSizes[j], m_TorchConstantInt(&constantSize)) &&
+            constantSize == -1) {
+          minusOneIdx = j;
+        } else {
+          knownProduct = affine::makeComposedFoldedAffineApply(
+              rewriter, loc, mulMap,
+              {knownProduct, sizeToOFR(reassocSizes[j])});
+        }
+      }
+
+      SmallVector<OpFoldResult> outputShape;
+      SmallVector<Value> inputSizes =
+          getTensorSizes(rewriter, loc, adaptor.getSelf());
+      for (int64_t i = 0; i < inputRank; ++i) {
+        if (i != dimInt) {
+          OpFoldResult inputDimSize =
+              (inputTensorSizes[i] != Torch::kUnknownSize)
+                  ? rewriter.getIndexAttr(inputTensorSizes[i])
+                  : OpFoldResult(inputSizes[i]);
+          outputShape.push_back(inputDimSize);
+          continue;
+        }
+
+        OpFoldResult inputDimSize =
+            (inputTensorSizes[dimInt] != Torch::kUnknownSize)
+                ? rewriter.getIndexAttr(inputTensorSizes[dimInt])
+                : OpFoldResult(inputSizes[dimInt]);
+        for (int64_t j = 0; j < numSizes; ++j) {
+          if (j == minusOneIdx) {
+            // The inferred size is input_dim floordiv product(known sizes).
+            auto divMap =
+                AffineMap::get(0, 2, s0.floorDiv(s1), rewriter.getContext());
+            outputShape.push_back(affine::makeComposedFoldedAffineApply(
+                rewriter, loc, divMap, {inputDimSize, knownProduct}));
+          } else {
+            outputShape.push_back(sizeToOFR(reassocSizes[j]));
+          }
+        }
+      }
+
+      SmallVector<int64_t> resultShape =
+          decomposeMixedValues(outputShape).first;
+      auto resultType =
+          RankedTensorType::get(resultShape, expandTy.getElementType());
+      expand = tensor::ExpandShapeOp::create(rewriter, loc, resultType,
+                                             adaptor.getSelf(), reassociations,
+                                             outputShape)
                   .getResult();
+
+      if (resultType != expandTy) {
+        expand =
+            rewriter.create<tensor::CastOp>(loc, expandTy, expand).getResult();
+      }
    } else {
      reassocSizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
                                            reassocSizes);
@@ -745,6 +817,7 @@ class ConvertAtenUnflattenIntOp
                    shapeValue)
                    .getResult();
    }
+
    rewriter.replaceOp(op, expand);
    return success();
  }
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index d1ddc42b39b1..1441eb1890f7 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1281,6 +1281,46 @@ def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 12, 3))
 
 
+class UnflattenIntDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, 12], torch.float32, True),
+        ]
+    )
+    def forward(self, inputs):
+        return torch.ops.aten.unflatten(inputs, 1, [3, 4])
+
+
+@register_test_case(module_factory=lambda: UnflattenIntDynamicModule())
+def UnflattenIntDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 12))
+
+
+class UnflattenIntDynamicWithInferredSizeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, 20], torch.float32, True),
+        ]
+    )
+    def forward(self, inputs):
+        return torch.ops.aten.unflatten(inputs, 1, [4, -1])
+
+
+@register_test_case(module_factory=lambda: UnflattenIntDynamicWithInferredSizeModule())
+def UnflattenIntDynamicWithInferredSizeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 20))
+
+
 # ==============================================================================
diff --git a/test/Conversion/TorchToLinalg/unflatten.mlir b/test/Conversion/TorchToLinalg/unflatten.mlir
new file mode 100644
index 000000000000..01049d4fac29
--- /dev/null
+++ b/test/Conversion/TorchToLinalg/unflatten.mlir
@@ -0,0 +1,74 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$static
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.expand_shape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$static(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,2,3,4],f32>
+  return %1 : !torch.vtensor<[2,2,3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$negative_dim
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.expand_shape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$negative_dim(%arg0: !torch.vtensor<[2,6,4],f32>) -> !torch.vtensor<[2,2,3,4],f32> {
+  %int-2 = torch.constant.int -2
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int-2, %0 : !torch.vtensor<[2,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[2,2,3,4],f32>
+  return %1 : !torch.vtensor<[2,2,3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$inferred_size
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.expand_shape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$inferred_size(%arg0: !torch.vtensor<[3,12],f32>) -> !torch.vtensor<[3,2,6],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int-1 = torch.constant.int -1
+  %0 = torch.prim.ListConstruct %int2, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[3,12],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[3,2,6],f32>
+  return %1 : !torch.vtensor<[3,2,6],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$dynamic_input
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.expand_shape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$dynamic_input(%arg0: !torch.vtensor<[?,6],f32>) -> !torch.vtensor<[?,2,3],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[?,6],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,2,3],f32>
+  return %1 : !torch.vtensor<[?,2,3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.unflatten.int$two_dynamic_dims
+// CHECK: torch_c.to_builtin_tensor
+// CHECK: tensor.from_elements
+// CHECK: tensor.reshape
+// CHECK: torch_c.from_builtin_tensor
+func.func @torch.aten.unflatten.int$two_dynamic_dims(%arg0: !torch.vtensor<[?,12],f32>) -> !torch.vtensor<[?,?,?],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,12],f32>, !torch.int -> !torch.int
+  %1 = torch.prim.ListConstruct %0, %0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.aten.unflatten.int %arg0, %int1, %1 : !torch.vtensor<[?,12],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
+  return %2 : !torch.vtensor<[?,?,?],f32>
+}
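The `$two_dynamic_dims` case above keeps the `tensor.from_elements` + `tensor.reshape` path because neither reassociation size is known at compile time, so no static expansion can be formed. A hedged eager-mode sketch of that situation (hypothetical helper, not part of the test suite):

```python
import torch

# Both sizes arrive as runtime values, so a compiler sees the result shape as
# fully dynamic ([?,?,?]); only their product is constrained to match the
# size of the dimension being unflattened.
def unflatten_dynamic(x: torch.Tensor, d0: int, d1: int) -> torch.Tensor:
    assert d0 * d1 == x.shape[1]
    return torch.ops.aten.unflatten(x, 1, [d0, d1])

print(unflatten_dynamic(torch.rand(2, 12), 3, 4).shape)  # torch.Size([2, 3, 4])
```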