Skip to content

Commit 5d80f41

Browse files
author
Wei Wei
committed
[fx2trt] improve pad/permute/setitem/getitem op (#72)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/72. 1. Support a fill value in the pad op. 2. Improve the permute op, since it previously could not handle x.permute(permutation) (a single list/tuple argument). 3. Improve setitem: fix an issue with slice(-n, None, None). 4. Improve getitem: support x[slice(None, None, None), slice(0, 0, None)]. This zero-size-slice case is needed for the setitem pass optimization: when we need to split a range where the dimension is (20, 10), (slice(None, None, None), slice(-10, None, None)) --> (slice(None, None, None), slice(0, -10, None)) + (slice(None, None, None), slice(-10, None, None)). Reviewed By: yinghai Differential Revision: D36223023 fbshipit-source-id: 7d87c74b85b5c6c6efcd46dac209a46af208d3f9
1 parent d418a37 commit 5d80f41

File tree

5 files changed

+75
-23
lines changed

5 files changed

+75
-23
lines changed

fx/converters/acc_ops_converters.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -332,10 +332,12 @@ def acc_ops_pad_with_slice_layer(
332332
f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension."
333333
)
334334

335-
if value != 0:
336-
raise RuntimeError(
337-
f"Currently we only support padding value of 0, got {value}."
338-
)
335+
# cast value to TRTensor
336+
dt = torch_dtype_from_trt(input_val.dtype)
337+
value = 0 if value == None else value
338+
value_const = get_trt_tensor(
339+
network, torch.tensor([value], dtype=dt), f"{name}_value"
340+
)
339341

340342
input_shape = input_val.shape
341343
pre_start = tuple(i - 1 for i in input_shape)
@@ -352,6 +354,7 @@ def acc_ops_pad_with_slice_layer(
352354
pre_shape,
353355
pre_stride,
354356
)
357+
layer.set_input(4, value_const)
355358
layer.mode = trt.SliceMode.FILL
356359
set_layer_name(layer, target, f"pre_{name}")
357360
half_pad_output = layer.get_output(0)
@@ -360,6 +363,7 @@ def acc_ops_pad_with_slice_layer(
360363
mid_start = tuple(i - 1 for i in shape)
361364
mid_stride = [-1] * len(shape)
362365
layer = network.add_slice(half_pad_output, mid_start, shape, mid_stride)
366+
layer.set_input(4, value_const)
363367
layer.mode = trt.SliceMode.FILL
364368
set_layer_name(layer, target, f"transpose_{name}")
365369
transpose_output = layer.get_output(0)
@@ -373,6 +377,7 @@ def acc_ops_pad_with_slice_layer(
373377
post_stride = tuple([1] * len(shape))
374378

375379
layer = network.add_slice(transpose_output, post_start, post_shape, post_stride)
380+
layer.set_input(4, value_const)
376381
layer.mode = trt.SliceMode.FILL
377382
set_layer_name(layer, target, f"post_{name}")
378383
return layer.get_output(0)
@@ -2776,9 +2781,15 @@ def slice_to_trt_params(py_slice, dim_size):
27762781
"""
27772782
Convert python slice to TensorRT slice layer parameters.
27782783
"""
2779-
start = get_positive_dim(py_slice.start, dim_size) if py_slice.start else 0
2780-
stride = py_slice.step if py_slice.step else 1
2781-
stop = get_positive_dim(py_slice.stop, dim_size) if py_slice.stop else dim_size
2784+
start = (
2785+
get_positive_dim(py_slice.start, dim_size) if py_slice.start != None else 0
2786+
)
2787+
stride = py_slice.step if py_slice.step != None else 1
2788+
stop = (
2789+
get_positive_dim(py_slice.stop, dim_size)
2790+
if py_slice.stop != None
2791+
else dim_size
2792+
)
27822793
size = math.ceil((stop - start) * 1.0 / stride)
27832794
return start, size, stride
27842795

@@ -2989,9 +3000,11 @@ def acc_ops_permute(
29893000
) -> Union[TRTTensor, Sequence[TRTTensor]]:
29903001
input_val = kwargs["input"]
29913002
ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr]
2992-
permutation = [
2993-
get_positive_dim(i, ranks) for i in cast(Sequence[int], kwargs["permutation"])
2994-
]
3003+
if len(kwargs["permutation"]) == 1:
3004+
index = kwargs["permutation"][0]
3005+
else:
3006+
index = kwargs["permutation"]
3007+
permutation = [get_positive_dim(i, ranks) for i in cast(Sequence[int], index)]
29953008

29963009
if not isinstance(input_val, TRTTensor):
29973010
raise RuntimeError(

fx/passes/lower_basic_pass.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import copy
22
import operator
3-
import operator
43
import warnings
54
from typing import Any
65

@@ -297,7 +296,7 @@ def split_across(
297296
start_node = end_node = mid_node = None
298297
if sli.start is None and sli.stop is None:
299298
return (start_node, input_node, end_node)
300-
if sli.start is not None and sli.start > 0:
299+
if sli.start is not None:
301300
st_sli = slice(0, sli.start, None)
302301
slice_list_gen = slice_list(st_sli, dim, size)
303302
start_node = gm.graph.call_function(
@@ -364,7 +363,11 @@ def transform_setitem(gm: torch.fx.GraphModule, input: Input):
364363
for ind, val in enumerate(new_args):
365364
if type(val) == int:
366365
inp_flag = True
367-
new_args[ind] = slice(val, val + 1, None)
366+
if val == -1:
367+
new_args[ind] = slice(-1, None, None)
368+
else:
369+
new_args[ind] = slice(val, val + 1, None)
370+
368371
if inp_flag:
369372
with gm.graph.inserting_before(inp):
370373
new_node = gm.graph.call_function(
@@ -375,7 +378,18 @@ def transform_setitem(gm: torch.fx.GraphModule, input: Input):
375378

376379
if type(sli) is not tuple:
377380
sli = [sli]
378-
sli = [slice(x, x + 1, None) if type(x) == int else x for x in sli]
381+
382+
tmp_sli = []
383+
for x in sli:
384+
if type(x) == int:
385+
if x == -1:
386+
tmp_sli.append(slice(-1, None, None))
387+
else:
388+
tmp_sli.append(slice(x, x + 1, None))
389+
else:
390+
tmp_sli.append(x)
391+
sli = tmp_sli
392+
379393
dimension = len(sli)
380394
with gm.graph.inserting_before(node):
381395
if dimension == 1:

test/converters/acc_op/test_getitem.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ class TestGetitemConverter(AccTestCase):
3737
"none",
3838
(slice(None, None, None), None, slice(1, -1, 3), 1),
3939
),
40+
(
41+
"slice_zero_slice",
42+
(slice(None, None, None), slice(None, None, None), slice(0, 0, None)),
43+
),
4044
]
4145
)
4246
def test_getitem(self, _, idx):

test/converters/acc_op/test_pad.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,42 @@
1212
class TestPadConverter(AccTestCase):
1313
@parameterized.expand(
1414
[
15-
("1d", (1, 2)),
16-
("2d", (2, 0, 0, 1)),
15+
("1d", (1, 2), 9),
16+
("2d", (2, 0, 0, 1), 10),
1717
]
1818
)
19-
def test_pad(self, _, pad):
19+
def test_pad_value(self, _, pad, value):
2020
class Pad(nn.Module):
2121
def forward(self, x):
22-
return torch.nn.functional.pad(x, pad)
22+
return torch.nn.functional.pad(x, pad, value=value)
2323

2424
inputs = [torch.randn(1, 2, 3, 4)]
2525
self.run_test(
2626
Pad(),
2727
inputs,
2828
expected_ops={acc_ops.pad},
29+
# enable value will not work with implicit batch
30+
test_implicit_batch_dim=False,
2931
)
3032

3133
@parameterized.expand(
3234
[
33-
param("value", pad=(2, 0, 0, 1), value=1),
35+
("1d", (1, 2)),
36+
("2d", (2, 0, 0, 1)),
3437
]
3538
)
36-
def test_pad_fail(self, _, pad, mode="constant", value=0):
39+
def test_pad(self, _, pad):
3740
class Pad(nn.Module):
3841
def forward(self, x):
39-
return torch.nn.functional.pad(x, pad, mode, value)
42+
return torch.nn.functional.pad(x, pad)
4043

4144
inputs = [torch.randn(1, 2, 3, 4)]
42-
self.run_test_with_assert_error(
45+
self.run_test(
4346
Pad(),
4447
inputs,
45-
expect_error=RuntimeError,
48+
expected_ops={acc_ops.pad},
49+
# enable value will not work with implicit batch
50+
test_implicit_batch_dim=False,
4651
)
4752

4853
@parameterized.expand(
@@ -64,6 +69,8 @@ def forward(self, x):
6469
Pad(),
6570
inputs,
6671
expected_ops={acc_ops.pad},
72+
# enable value will not work with implicit batch
73+
test_implicit_batch_dim=False,
6774
)
6875

6976

test/converters/acc_op/test_permute.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,20 @@
77

88

99
class TestPermuteConverter(AccTestCase):
10+
@parameterized.expand(
11+
[
12+
("positive", [0, 2, 1]),
13+
("negative", [0, -1, -2]),
14+
]
15+
)
16+
def test_permute_list(self, _, permutation):
17+
class Permute(nn.Module):
18+
def forward(self, x):
19+
return x.permute(permutation)
20+
21+
inputs = [torch.randn(1, 3, 2)]
22+
self.run_test(Permute(), inputs, expected_ops={acc_ops.permute})
23+
1024
@parameterized.expand(
1125
[
1226
("positive", [0, 2, 1]),

0 commit comments

Comments (0)