diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 335277aa4462..ece87bda88ae 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1894,6 +1894,63 @@ class DecomposeAtenAtleast1dOp : public OpRewritePattern<AtenAtleast1dOp> {
 };
 } // namespace
 
+// Decompose 'aten.outer' into 'aten.unsqueeze' and 'aten.matmul'.
+
+namespace {
+class DecomposeAtenOuterOp : public OpRewritePattern<AtenOuterOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenOuterOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    Value input = op.getSelf();
+    Value vec2 = op.getVec2();
+    Type opType = op.getType();
+
+    auto inputType = cast<BaseTensorType>(input.getType());
+    auto vec2Type = cast<BaseTensorType>(vec2.getType());
+
+    // Check that both inputs are ranked tensors.
+    if (!inputType.hasSizes() || !vec2Type.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "inputs must be ranked tensors for aten.outer");
+    }
+
+    // Check that both inputs are 1-dimensional vectors.
+    SmallVector<int64_t> inputShape(inputType.getSizes());
+    SmallVector<int64_t> vec2Shape(vec2Type.getSizes());
+
+    if (inputShape.size() != 1 || vec2Shape.size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "inputs must be 1-dimensional vectors for aten.outer");
+    }
+
+    Value one = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1)); // unsqueeze dim 1: (n) -> (n, 1)
+    inputShape.push_back(1);
+    Type inputMatrixType = inputType.getWithSizesAndDtype(
+        inputShape, inputType.getOptionalDtype());
+
+    Value inputMatrix =
+        rewriter.create<AtenUnsqueezeOp>(loc, inputMatrixType, input, one);
+
+    Value zero = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0)); // unsqueeze dim 0: (m) -> (1, m)
+    vec2Shape.insert(vec2Shape.begin(), 1);
+    Type vec2MatrixType =
+        vec2Type.getWithSizesAndDtype(vec2Shape, vec2Type.getOptionalDtype());
+
+    Value vec2Matrix =
+        rewriter.create<AtenUnsqueezeOp>(loc, vec2MatrixType, vec2, zero);
+
+    rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, opType, inputMatrix,
+                                              vec2Matrix);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose aten.atleast_2d into: aten.reshape. See
 // https://github.com/pytorch/pytorch/blob/9a8ab778d34bd24c5caceb340837483decc4c311/torch/_refs/__init__.py#L2604
@@ -11591,6 +11648,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenOuterOp>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 6695f2964b65..95bacd9fc9e6 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -379,6 +379,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenOuterOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 6904b4acb3c6..00a493c0db58 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3810,6 +3810,7 @@
 }
 
 ONNX_TOSA_XFAIL_SET = {
+    "AtenOuter_basic",
     "AtenFftRfft2DLastDim_basic",
     "AtenFftRfft2DMiddleDim_basic",
     "AtenNonzero1DDynamicModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
index 17240cf953df..25cb3cf57e35 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -918,3 +918,61 @@ def forward(self, a, b):
 @register_test_case(module_factory=lambda: AtenLinalgCrossDynamic())
 def AtenLinalgCrossDynamic_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))
+
+
+# ==============================================================================
+
+
+class AtenOuter(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3], torch.float32, True),
+            ([3], torch.float32, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return torch.outer(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: AtenOuter())
+def AtenOuter_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3))
+
+
+# ==============================================================================
+
+
+class AtenOuterDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.float32, True),
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return torch.outer(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: AtenOuterDynamic())
+def AtenOuterDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5), tu.rand(5))
+
+
+@register_test_case(module_factory=lambda: AtenOuterDynamic())
+def AtenOuterDynamic_lhs_larger(module, tu: TestUtils):
+    module.forward(tu.rand(7), tu.rand(4))
+
+
+@register_test_case(module_factory=lambda: AtenOuterDynamic())
+def AtenOuterDynamic_rhs_larger(module, tu: TestUtils):
+    module.forward(tu.rand(2), tu.rand(6))
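
The decomposition above relies on the identity torch.outer(a, b) == matmul(a.unsqueeze(1), b.unsqueeze(0)) for 1-D inputs. Below is a minimal sketch checking that identity in plain PyTorch, independent of torch-mlir; the helper name outer_via_matmul is illustrative only and not part of this patch:

    import torch

    def outer_via_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Mirror the rewrite emitted by DecomposeAtenOuterOp: unsqueeze the
        # LHS to (n, 1) and the RHS to (1, m), then matmul to get (n, m).
        assert a.dim() == 1 and b.dim() == 1
        return torch.matmul(a.unsqueeze(1), b.unsqueeze(0))

    lhs, rhs = torch.rand(7), torch.rand(4)
    assert torch.allclose(torch.outer(lhs, rhs), outer_via_matmul(lhs, rhs))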