Skip to content

Commit 6142f97

Browse files
author
Wei Wei
committed
[fx2trt] fix setitem (#69)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/69 fix the pass for corner cases: 1. position index is not a slice type but an int in setitem. Change it to a slice type. For example, [:, 1] -> [:, 1:2] 2. Similarly, for getitem (input for this setitem) that has an integer index like [:, 1], the output will remove the 2nd dimension. This causes trouble for setitem since we lack shape info. I will replace it with a slice like [:, 1:2] Reviewed By: yinghai Differential Revision: D36136098 fbshipit-source-id: c48f976fe639ff3fa159b44c8c2287f202898901
1 parent 2b49f98 commit 6142f97

File tree

2 files changed

+104
-6
lines changed

2 files changed

+104
-6
lines changed

fx/passes/lower_basic_pass.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import copy
2+
import operator
13
import operator
24
import warnings
35
from typing import Any
@@ -348,13 +350,32 @@ def transform_setitem(gm: torch.fx.GraphModule, input: Input):
348350
"""
349351
map_replace = {}
350352
for node in gm.graph.nodes:
353+
for old_node in map_replace:
354+
node.replace_input_with(old_node, map_replace[old_node])
355+
351356
if node.target == operator.setitem:
352357
input_node = node.args[0]
353358
sli = node.args[1]
354359
inp = node.args[2]
355360

361+
inp_flag = False
362+
if inp.target == operator.getitem:
363+
new_args = list(copy.deepcopy(inp.args[1]))
364+
for ind, val in enumerate(new_args):
365+
if type(val) == int:
366+
inp_flag = True
367+
new_args[ind] = slice(val, val + 1, None)
368+
if inp_flag:
369+
with gm.graph.inserting_before(inp):
370+
new_node = gm.graph.call_function(
371+
operator.getitem, args=(inp.args[0], new_args)
372+
)
373+
inp.replace_all_uses_with(new_node)
374+
inp = new_node
375+
356376
if type(sli) is not tuple:
357377
sli = [sli]
378+
sli = [slice(x, x + 1, None) if type(x) == int else x for x in sli]
358379
dimension = len(sli)
359380
with gm.graph.inserting_before(node):
360381
if dimension == 1:
@@ -418,9 +439,6 @@ def transform_setitem(gm: torch.fx.GraphModule, input: Input):
418439
continue
419440
node.replace_all_uses_with(concat_node_0)
420441
map_replace[input_node] = concat_node_0
421-
else:
422-
for old_node in map_replace:
423-
node.replace_input_with(old_node, map_replace[old_node])
424442

425443
gm.graph.eliminate_dead_code()
426444
gm.graph.lint()

test/passes/test_setitem.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,55 @@ def transform_fx(gm, example_inputs):
3232
with optimize_ctx:
3333
m(*inputs)
3434

35+
def test_setitem1d_c2(self):
36+
class TestModule(torch.nn.Module):
37+
def forward(self, x, y):
38+
y[:-1] = x
39+
y[1:] = x
40+
return y
41+
42+
inputs = [torch.randn(2), torch.randn(3)]
43+
m = TestModule()
44+
45+
inputs = [i.cuda() for i in inputs]
46+
m.cuda()
47+
48+
def transform_fx(gm, example_inputs):
49+
gm = transform_setitem(gm, example_inputs)
50+
return gm
51+
52+
optimize_ctx = torchdynamo.optimize(
53+
transform_fx,
54+
nopython=True,
55+
)
56+
57+
with optimize_ctx:
58+
m(*inputs)
59+
60+
def test_setitem1d_c3(self):
61+
class TestModule(torch.nn.Module):
62+
def forward(self, x, y):
63+
y[1] = x
64+
return y
65+
66+
inputs = [torch.randn(2), torch.randn(3)]
67+
m = TestModule()
68+
69+
inputs = [i.cuda() for i in inputs]
70+
m.cuda()
71+
72+
def transform_fx(gm, example_inputs):
73+
gm = transform_setitem(gm, example_inputs)
74+
return gm
75+
76+
optimize_ctx = torchdynamo.optimize(
77+
transform_fx,
78+
nopython=True,
79+
)
80+
81+
with optimize_ctx:
82+
m(*inputs)
83+
3584
@parameterized.expand(
3685
[
3786
("c1", (4, 2), (4, 5), 0, 2),
@@ -96,6 +145,37 @@ def transform_fx(gm, example_inputs):
96145
with optimize_ctx:
97146
m(*inputs)
98147

148+
@parameterized.expand(
149+
[
150+
("c1", (4, 2), (4, 2), 0, 1),
151+
]
152+
)
153+
def test_setitem2d_1v_ex2(self, name, x_shape, y_shape, y_start, y_end):
154+
class TestModule(torch.nn.Module):
155+
def __init__(self):
156+
super().__init__()
157+
158+
def forward(self, x, y):
159+
y[:, y_start:y_end] = x[:, 0]
160+
return y
161+
162+
inputs = [torch.randn(x_shape), torch.randn(y_shape)]
163+
m = TestModule()
164+
165+
inputs = [i.cuda() for i in inputs]
166+
m.cuda()
167+
168+
def transform_fx(gm, example_inputs):
169+
gm = transform_setitem(gm, example_inputs)
170+
return gm
171+
172+
optimize_ctx = torchdynamo.optimize(
173+
transform_fx,
174+
nopython=True,
175+
)
176+
with optimize_ctx:
177+
m(*inputs)
178+
99179
@parameterized.expand(
100180
[
101181
("c1", (3, 2), (4, 5), 0, 3, 0, 2),
@@ -429,17 +509,17 @@ def transform_fx(gm, example_inputs):
429509
with optimize_ctx:
430510
m(*inputs)
431511

432-
## test with torchdynamo
512+
# test with torchdynamo
433513
def test_setitem1d_trt(self):
434514
class TestModule(torch.nn.Module):
435515
def __init__(self):
436516
super().__init__()
437517

438518
def forward(self, x, y):
439-
y[0:2] = x
519+
y[1] = x
440520
return y
441521

442-
inputs = [torch.randn(2), torch.randn(3)]
522+
inputs = [torch.randn(1), torch.randn(3)]
443523
m = TestModule()
444524

445525
inputs = [i.cuda() for i in inputs]

0 commit comments

Comments
 (0)