Skip to content

Commit 4df6def

Browse files
committed
Addressed the problem with testing
1 parent 73dfc8a commit 4df6def

File tree

2 files changed

+6
-6
lines changed
  • lib/Conversion/TorchToLinalg
  • projects/pt1/python/torch_mlir_e2e_test/test_suite

2 files changed

+6
-6
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1688,8 +1688,8 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
16881688
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
16891689
return failure();
16901690
}
1691-
auto lhsType = cast<RankedTensorType>(lhs.getType());
1692-
auto rhsType = cast<RankedTensorType>(rhs.getType());
1691+
auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
1692+
auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
16931693

16941694
if (!lhsType || !rhsType)
16951695
return rewriter.notifyMatchFailure(op,
@@ -1698,8 +1698,8 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
16981698
return rewriter.notifyMatchFailure(
16991699
op, "outer: expected 1D tensors for outer op lowering");
17001700

1701-
Value lhsDim = getDimOp(rewriter, loc, lhs, 1);
1702-
Value rhsDim = getDimOp(rewriter, loc, rhs, 1);
1701+
Value lhsDim = getDimOp(rewriter, loc, lhs, 0);
1702+
Value rhsDim = getDimOp(rewriter, loc, rhs, 0);
17031703
Type elementType = lhsType.getElementType();
17041704
Type newResultType = getTypeConverter()->convertType(op.getType());
17051705

projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -931,8 +931,8 @@ def __init__(self):
931931
@annotate_args(
932932
[
933933
None,
934-
([-1], torch.float32, True),
935-
([-1], torch.float32, True),
934+
([3], torch.float32, True),
935+
([3], torch.float32, True),
936936
]
937937
)
938938
def forward(self, lhs, rhs):

0 commit comments

Comments (0)