
Commit 1e93a45

Removed some obsolete workarounds in ts_compile and added a new one (#875)
1 parent e428146 commit 1e93a45

2 files changed: +4 −19 lines changed

functorch/_src/compilers.py

Lines changed: 3 additions & 19 deletions
@@ -34,16 +34,9 @@ def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
         Torch scripted model.
     """
     for node in fx_g.graph.nodes:
-        if node.target in (torch.ops.aten.new_zeros, torch.ops.aten.new_empty):
-            if node.args[1] == []:
-                args = list(node.args)
-                args[1] = [1]
-                node.args = tuple(args)
-        elif node.target is torch.ops.aten.masked_fill and node.args[2] == float("-inf"):
-            # Fx graph to torchscript fails for -inf
-            args = list(node.args)
-            args[2] = -3.403 * 10**37
-            node.args = tuple(args)
+        if (node.target == torch.ops.aten._to_copy and len(node.args) == 1
+                and len(node.kwargs) == 1 and 'dtype' in node.kwargs):
+            node.target = torch.ops.aten.to

     for node in fx_g.graph.nodes:
         new_kwargs = {}
@@ -55,15 +48,6 @@ def ts_compile(fx_g: fx.GraphModule, _) -> Callable:

     fx_g.graph.lint()

-    # print(set([i.target for i in fx_g.graph.nodes if i.op == 'call_function']))
-    # Works around this NVFuser issue: https://github.com/csarofeen/pytorch/issues/1311
-    for i in range(1000):
-        attr = f"_tensor_constant{i}"
-        if hasattr(fx_g, attr):
-            setattr(fx_g, attr, getattr(fx_g, attr).cuda())
-        else:
-            break
-
     fx_g.recompile()

     f = torch.jit.script(fx_g)
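The new workaround retargets any aten._to_copy node that carries a single positional argument and only a dtype keyword to aten.to before the graph is handed to torch.jit.script. Below is a minimal, self-contained sketch of that retargeting pattern, not the functorch source itself: the hand-built FX graph, the names gm and cast, and the final print are illustrative assumptions used only to show the rewrite on a graph that is guaranteed to contain such a node.

import torch
import torch.fx as fx

# Build a tiny FX graph whose single call is aten._to_copy(x, dtype=...),
# which is how a dtype-only Tensor.to() typically appears after tracing
# through the dispatcher.
graph = fx.Graph()
x = graph.placeholder("x")
cast = graph.call_function(torch.ops.aten._to_copy, (x,), {"dtype": torch.float64})
graph.output(cast)
gm = fx.GraphModule(torch.nn.Module(), graph)

# The same retargeting ts_compile now performs: rewrite a dtype-only
# aten._to_copy to aten.to, which the commit suggests scripts more cleanly.
for node in gm.graph.nodes:
    if (node.target == torch.ops.aten._to_copy and len(node.args) == 1
            and len(node.kwargs) == 1 and "dtype" in node.kwargs):
        node.target = torch.ops.aten.to

gm.graph.lint()
gm.recompile()
scripted = torch.jit.script(gm)
print(scripted(torch.randn(3)).dtype)  # expected: torch.float64

Running the sketch confirms the rewritten graph still performs the requested cast after scripting.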

functorch/_src/python_key.py

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@
 from torch.fx.experimental.proxy_tensor import make_fx, ProxyTensor, dispatch_trace, PythonKeyTracer, decompose

 pythonkey_decompose = decompose
+PythonTensor = ProxyTensor
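The one-line change to python_key.py keeps PythonTensor available as a backward-compatible alias for ProxyTensor, which now comes from torch.fx.experimental.proxy_tensor. A hedged sketch of what that buys callers, assuming the module layout at this commit (import paths may differ in later releases):

# Older code that imported the class under its old name keeps working.
from functorch._src.python_key import PythonTensor
from torch.fx.experimental.proxy_tensor import ProxyTensor

# The alias simply re-exports the same class.
assert PythonTensor is ProxyTensor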
