Skip to content

Commit a52aded

Browse files
authored
Add lowering for slice and selectInt (#398)
1 parent 46a2189 commit a52aded

File tree

9 files changed

+416
-15
lines changed

9 files changed

+416
-15
lines changed

e2e_testing/torchscript/basic.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ def forward(self, lhs, rhs):
100100
def matmul(self, lhs, rhs):
101101
return torch.mm(lhs, rhs)
102102

103+
# ==============================================================================
104+
103105

104106
@register_test_case(module_factory=lambda: MmTanhModule())
105107
def MmTanhModule_basic(module, tu: TestUtils):
@@ -192,6 +194,8 @@ def forward(self, x):
192194
def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
193195
module.forward(tu.rand(10, 3, 8, 9))
194196

197+
# ==============================================================================
198+
195199

196200
class FlattenStaticModule(torch.nn.Module):
197201
def __init__(self):
@@ -211,6 +215,8 @@ def forward(self, x):
211215
def FlattenStaticModule_basic(module, tu: TestUtils):
212216
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
213217

218+
# ==============================================================================
219+
214220

215221
class FlattenRank0Module(torch.nn.Module):
216222
def __init__(self):
@@ -230,6 +236,8 @@ def forward(self, x):
230236
def FlattenRank0Module_basic(module, tu: TestUtils):
231237
module.forward(torch.tensor(4.0))
232238

239+
# ==============================================================================
240+
233241

234242
class FlattenDynamicModule(torch.nn.Module):
235243
def __init__(self):
@@ -249,6 +257,8 @@ def forward(self, x):
249257
def FlattenDynamicModule_basic(module, tu: TestUtils):
250258
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
251259

260+
# ==============================================================================
261+
252262

253263
class MaxPool2dModule(torch.nn.Module):
254264
def __init__(self):
@@ -266,6 +276,8 @@ def __init__(self):
266276
def forward(self, x):
267277
return self.mp2d(x)
268278

279+
# ==============================================================================
280+
269281

270282
@register_test_case(module_factory=lambda: MaxPool2dModule())
271283
def MaxPool2dModule_basic(module, tu: TestUtils):
@@ -284,6 +296,8 @@ def __init__(self):
284296
def forward(self, x):
285297
return torch.transpose(x, 0, 1)
286298

299+
# ==============================================================================
300+
287301

288302
@register_test_case(module_factory=lambda: TransposeIntModule())
289303
def TransposeIntModule_basic(module, tu: TestUtils):
@@ -305,6 +319,8 @@ def forward(self, x):
305319
def PermuteModule_basic(module, tu: TestUtils):
306320
module.forward(tu.rand(3, 4, 2))
307321

322+
# ==============================================================================
323+
308324
class TransposeIntNegDimsModule(torch.nn.Module):
309325
def __init__(self):
310326
super().__init__()
@@ -317,6 +333,8 @@ def __init__(self):
317333
def forward(self, x):
318334
return torch.transpose(x, -1, -2)
319335

336+
# ==============================================================================
337+
320338

321339
@register_test_case(module_factory=lambda: TransposeIntNegDimsModule())
322340
def TransposeIntNegDimsModule_basic(module, tu: TestUtils):
@@ -335,6 +353,8 @@ def __init__(self):
335353
def forward(self, x):
336354
return x.permute(0, -1, 1)
337355

356+
# ==============================================================================
357+
338358
@register_test_case(module_factory=lambda: PermuteNegativeIndexModule())
339359
def PermuteNegativeIndexModule_basic(module, tu: TestUtils):
340360
module.forward(tu.rand(3, 4, 2))
@@ -357,6 +377,8 @@ def forward(self, x, y, z):
357377
def TensorsConcatModule_basic(module, tu: TestUtils):
358378
module.forward(tu.rand(2, 2, 4), tu.rand(2, 1, 4), tu.rand(2, 3, 4))
359379

380+
# ==============================================================================
381+
360382

361383
class GatherModule(torch.nn.Module):
362384
def __init__(self):
@@ -376,6 +398,8 @@ def forward(self, tensor, indices):
376398
def GatherModule_basic(module, tu: TestUtils):
377399
module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]]))
378400

401+
# ==============================================================================
402+
379403
class AddSizeIntModule(torch.nn.Module):
380404
def __init__(self):
381405
super().__init__()
@@ -396,6 +420,8 @@ def forward(self, tensor):
396420
def AddSizeIntModule_basic(module, tu: TestUtils):
397421
module.forward(torch.randn(3, 3))
398422

423+
# ==============================================================================
424+
399425

400426
class AddSizeIntNegDimModule(torch.nn.Module):
401427
def __init__(self):
@@ -417,6 +443,8 @@ def forward(self, tensor):
417443
def AddSizeIntNegDimModule_basic(module, tu: TestUtils):
418444
module.forward(torch.randn(3, 3))
419445

446+
# ==============================================================================
447+
420448
class EmbeddingModule(torch.nn.Module):
421449
def __init__(self):
422450
super().__init__()
@@ -438,6 +466,7 @@ def forward(self, indices):
438466
def EmbeddingModule_basic(module, tu: TestUtils):
439467
module.forward(torch.randint(100, (3, 3)))
440468

469+
# ==============================================================================
441470

442471
class SoftmaxIntModule(torch.nn.Module):
443472
def __init__(self):
@@ -474,6 +503,8 @@ def forward(self, tensor):
474503
def _SoftmaxModule_basic(module, tu: TestUtils):
475504
module.forward(torch.randn(3, 2, 4))
476505

506+
# ==============================================================================
507+
477508

478509
class SoftmaxIntNegDimModule(torch.nn.Module):
479510
def __init__(self):
@@ -494,6 +525,8 @@ def forward(self, tensor):
494525
def SoftmaxIntNegDimModule_basic(module, tu: TestUtils):
495526
module.forward(torch.randn(3, 2, 4))
496527

528+
# ==============================================================================
529+
497530

498531
class SoftmaxIntArgTypeF64Module(torch.nn.Module):
499532
def __init__(self):
@@ -513,6 +546,7 @@ def forward(self, tensor):
513546
def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
514547
module.forward(torch.randn(3, 2, 4).double())
515548

549+
# ==============================================================================
516550

517551
class BroadcastToModule(torch.nn.Module):
518552
def __init__(self):
@@ -531,6 +565,8 @@ def forward(self, x):
531565
def BroadcastToModule_basic(module, tu: TestUtils):
532566
module.forward(tu.rand(3, 1, 1))
533567

568+
# ==============================================================================
569+
534570
class ExpandModule(torch.nn.Module):
535571
def __init__(self):
536572
super().__init__()
@@ -548,6 +584,9 @@ def forward(self, x):
548584
def ExpandModule_basic(module, tu: TestUtils):
549585
module.forward(tu.rand(3, 1, 1))
550586

587+
# ==============================================================================
588+
589+
551590
class OnesModuleInt(torch.nn.Module):
552591
def __init__(self):
553592
super().__init__()
@@ -563,6 +602,8 @@ def forward(self):
563602
def OnesModuleInt_basic(module, tu: TestUtils):
564603
module.forward()
565604

605+
# ==============================================================================
606+
566607
class OnesModuleFloat(torch.nn.Module):
567608
def __init__(self):
568609
super().__init__()
@@ -594,6 +635,7 @@ def forward(self):
594635
def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
595636
module.forward()
596637

638+
# ==============================================================================
597639

598640
class ContiguousModule(torch.nn.Module):
599641
def __init__(self):
@@ -611,7 +653,7 @@ def forward(self, x):
611653
@register_test_case(module_factory=lambda: ContiguousModule())
612654
def ContiguousModule_basic(module, tu: TestUtils):
613655
module.forward(tu.rand(3, 1))
614-
656+
615657
class TensorToInt(torch.nn.Module):
616658
def __init__(self):
617659
super().__init__()
@@ -681,6 +723,7 @@ def forward(self):
681723
def NumToTensorFloatModule_basic(module, tu: TestUtils):
682724
module.forward()
683725

726+
# ==============================================================================
684727

685728
# This test can be removed once we have one real op returning 3 float32 tensors
686729
class ReturnThreeTensorFloat32(torch.nn.Module):

e2e_testing/torchscript/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from . import view
4343
from . import scalar
4444
from . import squeeze
45+
from . import slice_like
4546

4647
def _get_argparse():
4748
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']

0 commit comments

Comments
 (0)