Skip to content

Commit 803a91b

Browse files
committed
Addressed the feedback
1 parent 6fefb77 commit 803a91b

File tree

2 files changed

+27
-23
lines changed

2 files changed

+27
-23
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1911,37 +1911,41 @@ class DecomposeAtenOuterOp : public OpRewritePattern<AtenOuterOp> {
19111911
auto inputType = cast<BaseTensorType>(input.getType());
19121912
auto vec2Type = cast<BaseTensorType>(vec2.getType());
19131913

1914+
// Check if tensors not empty
1915+
if (!inputType.hasSizes() || !vec2Type.hasSizes()) {
1916+
return rewriter.notifyMatchFailure(
1917+
op, "Inputs must be ranked tensors for aten.outer");
1918+
}
1919+
19141920
// Check if both tensors are 1-dimensional
19151921
SmallVector<int64_t> inputShape(inputType.getSizes());
19161922
SmallVector<int64_t> vec2Shape(vec2Type.getSizes());
19171923

1918-
if (inputShape.size() == 1 && vec2Shape.size() == 1) {
1924+
if (inputShape.size() != 1 || vec2Shape.size() != 1) {
1925+
return rewriter.notifyMatchFailure(
1926+
op, "Inputs must be 1-dimensional vectors for aten.outer");
1927+
}
19191928

1920-
Value one = rewriter.create<Torch::ConstantIntOp>(
1921-
loc, rewriter.getI64IntegerAttr(1)); // Dimension index
1922-
SmallVector<int64_t, 2> inputMatrixShape = {inputShape[0], 1};
1923-
Type inputMatrixType = inputType.getWithSizesAndDtype(
1924-
inputMatrixShape, inputType.getOptionalDtype());
1929+
Value one = rewriter.create<Torch::ConstantIntOp>(
1930+
loc, rewriter.getI64IntegerAttr(1)); // Dimension index
1931+
SmallVector<int64_t, 2> inputMatrixShape = {inputShape[0], 1};
1932+
Type inputMatrixType = inputType.getWithSizesAndDtype(
1933+
inputMatrixShape, inputType.getOptionalDtype());
19251934

1926-
Value inputMatrix =
1927-
rewriter.create<AtenUnsqueezeOp>(loc, inputMatrixType, input, one);
1935+
Value inputMatrix =
1936+
rewriter.create<AtenUnsqueezeOp>(loc, inputMatrixType, input, one);
19281937

1929-
Value zero = rewriter.create<Torch::ConstantIntOp>(
1930-
loc, rewriter.getI64IntegerAttr(0));
1931-
SmallVector<int64_t, 2> vec2MatrixShape = {1, vec2Shape[0]};
1932-
Type vec2MatrixType = vec2Type.getWithSizesAndDtype(
1933-
vec2MatrixShape, vec2Type.getOptionalDtype());
1934-
1935-
Value vec2Matrix =
1936-
rewriter.create<AtenUnsqueezeOp>(loc, vec2MatrixType, vec2, zero);
1938+
Value zero = rewriter.create<Torch::ConstantIntOp>(
1939+
loc, rewriter.getI64IntegerAttr(0));
1940+
SmallVector<int64_t, 2> vec2MatrixShape = {1, vec2Shape[0]};
1941+
Type vec2MatrixType = vec2Type.getWithSizesAndDtype(
1942+
vec2MatrixShape, vec2Type.getOptionalDtype());
19371943

1938-
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, opType, inputMatrix,
1939-
vec2Matrix);
1940-
return success();
1941-
} else {
1942-
return failure();
1943-
}
1944+
Value vec2Matrix =
1945+
rewriter.create<AtenUnsqueezeOp>(loc, vec2MatrixType, vec2, zero);
19441946

1947+
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, opType, inputMatrix,
1948+
vec2Matrix);
19451949
return success();
19461950
}
19471951
};

projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ def forward(self, lhs, rhs):
941941

942942
@register_test_case(module_factory=lambda: AtenOuter())
943943
def AtenOuter_basic(module, tu: TestUtils):
944-
module.forward(tu.rand(3), tu.rand(3))
944+
module.forward(tu.rand(3), tu.rand(2))
945945

946946

947947
# ==============================================================================

0 commit comments

Comments
 (0)