Skip to content

Commit 0e88737

Browse files
committed
Addressed the feedback
1 parent 4df6def commit 0e88737

File tree

2 files changed

+39
-3
lines changed
  • lib/Conversion/TorchToLinalg
  • projects/pt1/python/torch_mlir_e2e_test/test_suite

2 files changed

+39
-3
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1683,7 +1683,7 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
16831683

16841684
Location loc = op->getLoc();
16851685
Value lhs = adaptor.getSelf();
1686-
Value rhs = op->getOperand(1);
1686+
Value rhs = adaptor.getVec2();
16871687

16881688
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
16891689
return failure();
@@ -1704,8 +1704,10 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
17041704
Type newResultType = getTypeConverter()->convertType(op.getType());
17051705

17061706
// Create a zero-initialized tensor with shape [lhsDim, rhsDim]
1707-
Value initTensor = createInitTensor(
1708-
rewriter, loc, ValueRange{lhsDim, rhsDim}, elementType, NULL);
1707+
SmallVector<OpFoldResult> resultShape =
1708+
getAsOpFoldResult(ValueRange{lhsDim, rhsDim});
1709+
Value initTensor =
1710+
rewriter.create<tensor::EmptyOp>(loc, resultShape, elementType);
17091711

17101712
// Set up affine indexing maps:
17111713
// We create a 2D loop iteration space. For the lhs, we use the first index

projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,3 +942,37 @@ def forward(self, lhs, rhs):
942942
@register_test_case(module_factory=lambda: AtenOuter())
943943
def AtenOuter_basic(module, tu: TestUtils):
944944
module.forward(tu.rand(3), tu.rand(3))
945+
946+
947+
# ==============================================================================
948+
949+
950+
class AtenOuterDynamic(torch.nn.Module):
951+
def __init__(self):
952+
super().__init__()
953+
954+
@export
955+
@annotate_args(
956+
[
957+
None,
958+
([-1], torch.float32, True),
959+
([-1], torch.float32, True),
960+
]
961+
)
962+
def forward(self, lhs, rhs):
963+
return torch.outer(lhs, rhs)
964+
965+
966+
@register_test_case(module_factory=lambda: AtenOuterDynamic())
967+
def AtenOuterDynamic_basic(module, tu: TestUtils):
968+
module.forward(tu.rand(5), tu.rand(5))
969+
970+
971+
@register_test_case(module_factory=lambda: AtenOuterDynamic())
972+
def AtenOuterDynamic_lhs_larger(module, tu: TestUtils):
973+
module.forward(tu.rand(7), tu.rand(4))
974+
975+
976+
@register_test_case(module_factory=lambda: AtenOuterDynamic())
977+
def AtenOuterDynamic_rhs_larger(module, tu: TestUtils):
978+
module.forward(tu.rand(2), tu.rand(6))

0 commit comments

Comments
 (0)