Skip to content

Commit 446292f

Browse files
committed
Addressed the problem with testing
1 parent d5623a7 commit 446292f

File tree

2 files changed

+6
-6
lines changed
  • lib/Conversion/TorchToLinalg
  • projects/pt1/python/torch_mlir_e2e_test/test_suite

2 files changed

+6
-6
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,8 +1774,8 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
17741774
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
17751775
return failure();
17761776
}
1777-
auto lhsType = cast<RankedTensorType>(lhs.getType());
1778-
auto rhsType = cast<RankedTensorType>(rhs.getType());
1777+
auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
1778+
auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
17791779

17801780
if (!lhsType || !rhsType)
17811781
return rewriter.notifyMatchFailure(op,
@@ -1784,8 +1784,8 @@ class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
17841784
return rewriter.notifyMatchFailure(
17851785
op, "outer: expected 1D tensors for outer op lowering");
17861786

1787-
Value lhsDim = getDimOp(rewriter, loc, lhs, 1);
1788-
Value rhsDim = getDimOp(rewriter, loc, rhs, 1);
1787+
Value lhsDim = getDimOp(rewriter, loc, lhs, 0);
1788+
Value rhsDim = getDimOp(rewriter, loc, rhs, 0);
17891789
Type elementType = lhsType.getElementType();
17901790
Type newResultType = getTypeConverter()->convertType(op.getType());
17911791

projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -931,8 +931,8 @@ def __init__(self):
931931
@annotate_args(
932932
[
933933
None,
934-
([-1], torch.float32, True),
935-
([-1], torch.float32, True),
934+
([3], torch.float32, True),
935+
([3], torch.float32, True),
936936
]
937937
)
938938
def forward(self, lhs, rhs):

0 commit comments

Comments
 (0)