Skip to content

Commit 2b49f98

Browse files
alexbeloiWei Wei
authored andcommitted
[fx][acc_ops] add acc_op mapper for torch.tensor_split (#66)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/66 Reviewed By: 842974287 Differential Revision: D36126848 fbshipit-source-id: a8324bd773105be7cbdad00681cb971aa09f1822
1 parent 6cf5620 commit 2b49f98

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

test/tracer/test_acc_tracer.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2336,6 +2336,43 @@ def forward(self, a, b, c):
23362336
res = traced(cond, x, y)
23372337
self.assertTrue(torch.equal(ref, res))
23382338

2339+
@parameterized.expand(
2340+
[
2341+
("sections divisible", 2, 0),
2342+
("sections indivisible", 3, 0),
2343+
("indices list", [1, 3], 0),
2344+
("indices tuple", (1, 3), 0),
2345+
("indices tensor", torch.tensor([1, 3]), 0),
2346+
("indices tensor dim1", torch.tensor([1, 3]), 1),
2347+
("indices tensor dim2", torch.tensor([1, 3]), 2),
2348+
("indices tensor long dim2", torch.tensor([1, 3, 5, 7]), 2),
2349+
]
2350+
)
2351+
def test_tensor_split(self, _, indices_or_sections, dim):
2352+
"""
2353+
Test that the tracer works for torch.tensor_split with indices and sections
2354+
"""
2355+
2356+
class TestModule(nn.Module):
2357+
def __init__(self, indices_or_sections, dim):
2358+
super().__init__()
2359+
self._indices_or_sections = indices_or_sections
2360+
self._dim = dim
2361+
2362+
def forward(self, a):
2363+
return torch.tensor_split(a, self._indices_or_sections, self._dim)
2364+
2365+
m = TestModule(indices_or_sections, dim)
2366+
a = torch.randn(4, 8, 16)
2367+
traced = acc_tracer.trace(m, [a])
2368+
2369+
results = traced(a)
2370+
references = m(a)
2371+
for res, ref in zip(results, references):
2372+
self.assertTrue(
2373+
torch.equal(ref, res), f"Tensors at don't match {ref=} {res=}"
2374+
)
2375+
23392376
def test_all_acc_ops_registered(self):
23402377
self.assertEqual(
23412378
acc_normalizer._acc_ops,
@@ -2453,5 +2490,6 @@ def test_all_acc_ops_registered(self):
24532490
acc_ops.dtype,
24542491
acc_ops.isinf,
24552492
acc_ops.any,
2493+
acc_ops.tensor_split,
24562494
},
24572495
)

tracer/acc_tracer/acc_ops.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44

55
import torch # isort:skip
6-
from typing import Sequence, List, cast
6+
from typing import Iterable, Sequence, List, cast
77

88
import fx2trt_oss.tracer.acc_tracer.acc_utils as acc_utils
99
import torch.nn as nn
@@ -2616,3 +2616,39 @@ def interpolate(
26162616
mode=mode,
26172617
align_corners=align_corners,
26182618
)
2619+
2620+
2621+
@register_acc_op_mapping(
2622+
op_and_target=("call_function", torch.tensor_split),
2623+
arg_replacement_tuples=[
2624+
("input", "input"),
2625+
(("tensor_indices_or_sections", "sections", "indices"), "indices_or_sections"),
2626+
("dim", "dim", this_arg_is_optional),
2627+
],
2628+
)
2629+
@register_acc_op_mapping(
2630+
op_and_target=("call_method", "tensor_split"),
2631+
arg_replacement_tuples=[
2632+
("input", "input"),
2633+
(("tensor_indices_or_sections", "sections", "indices"), "indices_or_sections"),
2634+
("dim", "dim", this_arg_is_optional),
2635+
],
2636+
)
2637+
@register_acc_op
2638+
def tensor_split(*, input, indices_or_sections, dim=0):
2639+
# Need to de-coalesce the indices_or_sections because tensor_split accepts
2640+
# one of three kwarg signatures:
2641+
# * (Tensor input, Tensor tensor_indices_or_sections, int dim)
2642+
# * (Tensor input, int sections, int dim)
2643+
# * (Tensor input, tuple of ints indices, int dim)
2644+
if isinstance(indices_or_sections, torch.Tensor):
2645+
indices_or_sections = indices_or_sections.tolist()
2646+
if isinstance(indices_or_sections, int):
2647+
return torch.tensor_split(input, sections=indices_or_sections, dim=dim)
2648+
elif isinstance(indices_or_sections, Iterable):
2649+
return torch.tensor_split(input, indices=tuple(indices_or_sections), dim=dim)
2650+
else:
2651+
raise RuntimeError(
2652+
f"Expected int, Iterable or Tensor for "
2653+
f"indices_or_sections arg, got: {type(indices_or_sections)}"
2654+
)

0 commit comments

Comments
 (0)