Skip to content

Commit 396ab35

Browse files
authored
Small fixes for slice edge cases (#476)
1 parent a6c3050 commit 396ab35

File tree

3 files changed

+24
-14
lines changed

3 files changed

+24
-14
lines changed

e2e_testing/torchscript/slice_like.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def SliceModule_basic(module, tu: TestUtils):
3232

3333
# ==============================================================================
3434

35-
# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
3635
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
3736
def __init__(self):
3837
super().__init__()
@@ -43,8 +42,11 @@ def __init__(self):
4342
([-1, -1, -1], torch.float32, True),
4443
])
4544
def forward(self, x):
46-
return x[:8, :5, 8:]
47-
45+
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
46+
result = x[:8, :5, 8:]
47+
cat_tensor = torch.ones((6,4,1), dtype=torch.float32)
48+
return torch.cat((result,cat_tensor), dim=2)
49+
4850

4951
@register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexModule())
5052
def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils):
@@ -90,7 +92,7 @@ def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils):
9092

9193
# ==============================================================================
9294

93-
# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
95+
9496
class SliceEndSleStartModule(torch.nn.Module):
9597
def __init__(self):
9698
super().__init__()
@@ -101,7 +103,10 @@ def __init__(self):
101103
([-1, -1, -1], torch.float32, True),
102104
])
103105
def forward(self, x):
104-
return x[:0, 4:3, :-7]
106+
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
107+
result = x[:, 4:3, :]
108+
cat_tensor = torch.ones((6,1,7), dtype=torch.float32)
109+
return torch.cat((result, cat_tensor), dim=1)
105110

106111

107112
@register_test_case(module_factory=lambda: SliceEndSleStartModule())
@@ -110,7 +115,7 @@ def SliceEndSleStartModule_basic(module, tu: TestUtils):
110115

111116
# ==============================================================================
112117

113-
# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
118+
114119
class SliceStartEqEndModule(torch.nn.Module):
115120
def __init__(self):
116121
super().__init__()
@@ -121,7 +126,10 @@ def __init__(self):
121126
([-1, -1, -1], torch.float32, True),
122127
])
123128
def forward(self, x):
124-
return x[5:5, 3:3, -1:]
129+
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
130+
result = x[5:5, :, :]
131+
cat_tensor = torch.ones((1,4,7), dtype=torch.float32)
132+
return torch.cat((result, cat_tensor), dim=0)
125133

126134

127135
@register_test_case(module_factory=lambda: SliceStartEqEndModule())

e2e_testing/torchscript/xfail_sets.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,7 @@
1717
"QuantizedMLP_basic",
1818
"IouOfModule_basic",
1919
}
20-
# Fails due to https://github.com/llvm/torch-mlir/issues/448
21-
SIZE_ZERO_TENSOR_XFAILS = {
22-
"SliceEndSleStartModule_basic",
23-
"SliceStartEqEndModule_basic",
24-
"SliceOutOfUpperBoundIndexModule_basic",
25-
}
26-
REFBACKEND_XFAIL_SET = set.union(COMMON_TORCH_MLIR_LOWERING_XFAILS, SIZE_ZERO_TENSOR_XFAILS)
20+
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
2721

2822
# Write the TOSA set as a "passing" set as it is very early in development
2923
# and very few tests work yet.

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3047,9 +3047,17 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
30473047
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
30483048
};
30493049

3050+
if (op.start().getType().isa<OptionalType>() ||
3051+
op.end().getType().isa<OptionalType>())
3052+
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
30503053
Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero);
30513054
Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize);
30523055

3056+
// end >= start ? end : start
3057+
Value endSgeStart = rewriter.create<arith::CmpIOp>(
3058+
loc, arith::CmpIPredicate::sge, end, start);
3059+
end = rewriter.create<SelectOp>(loc, endSgeStart, end, start);
3060+
30533061
int64_t step;
30543062
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
30553063
if (!op.step().getType().isa<Torch::NoneType>())

0 commit comments

Comments
 (0)