@@ -32,7 +32,6 @@ def SliceModule_basic(module, tu: TestUtils):
3232
3333# ==============================================================================
3434
35- # This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
3635class SliceOutOfUpperBoundIndexModule (torch .nn .Module ):
3736 def __init__ (self ):
3837 super ().__init__ ()
@@ -43,8 +42,11 @@ def __init__(self):
4342 ([- 1 , - 1 , - 1 ], torch .float32 , True ),
4443 ])
4544 def forward (self , x ):
46- return x [:8 , :5 , 8 :]
47-
45+ # TODO: remove hacky cat tensor once refbackend supports 0 size dim
46+ result = x [:8 , :5 , 8 :]
47+ cat_tensor = torch .ones ((6 ,4 ,1 ), dtype = torch .float32 )
48+ return torch .cat ((result ,cat_tensor ), dim = 2 )
49+
4850
4951@register_test_case (module_factory = lambda : SliceOutOfUpperBoundIndexModule ())
5052def SliceOutOfUpperBoundIndexModule_basic (module , tu : TestUtils ):
@@ -90,7 +92,7 @@ def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils):
9092
9193# ==============================================================================
9294
93- # This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
95+
9496class SliceEndSleStartModule (torch .nn .Module ):
9597 def __init__ (self ):
9698 super ().__init__ ()
@@ -101,7 +103,10 @@ def __init__(self):
101103 ([- 1 , - 1 , - 1 ], torch .float32 , True ),
102104 ])
103105 def forward (self , x ):
104- return x [:0 , 4 :3 , :- 7 ]
106+ # TODO: remove hacky cat tensor once refbackend supports 0 size dim
107+ result = x [:, 4 :3 , :]
108+ cat_tensor = torch .ones ((6 ,1 ,7 ), dtype = torch .float32 )
109+ return torch .cat ((result , cat_tensor ), dim = 1 )
105110
106111
107112@register_test_case (module_factory = lambda : SliceEndSleStartModule ())
@@ -110,7 +115,7 @@ def SliceEndSleStartModule_basic(module, tu: TestUtils):
110115
111116# ==============================================================================
112117
113- # This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
118+
114119class SliceStartEqEndModule (torch .nn .Module ):
115120 def __init__ (self ):
116121 super ().__init__ ()
@@ -121,7 +126,10 @@ def __init__(self):
121126 ([- 1 , - 1 , - 1 ], torch .float32 , True ),
122127 ])
123128 def forward (self , x ):
124- return x [5 :5 , 3 :3 , - 1 :]
129+ # TODO: remove hacky cat tensor once refbackend supports 0 size dim
130+ result = x [5 :5 , :, :]
131+ cat_tensor = torch .ones ((1 ,4 ,7 ), dtype = torch .float32 )
132+ return torch .cat ((result , cat_tensor ), dim = 0 )
125133
126134
127135@register_test_case (module_factory = lambda : SliceStartEqEndModule ())
0 commit comments