Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit d11b5c2

Browse files
committed
made partitioner strip overloads
1 parent 7b70939 commit d11b5c2

File tree

6 files changed

+21
-17
lines changed

6 files changed

+21
-17
lines changed

benchmarks/cse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
 from functorch import make_fx
 from torch.profiler import profile, ProfilerActivity
 
-from functorch._src.cse import fx_graph_cse
+from functorch._src.compile_utils import fx_graph_cse
 
 def profile_it(f, inp):
     for _ in range(5):

functorch/_src/cse.py renamed to functorch/_src/compile_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
+
 import torch
 import torch.fx as fx
 from torch.utils._pytree import tree_flatten
@@ -56,3 +57,16 @@ def substitute(arg_list):
         token_map[hash_val] = token
 
     return new_graph
+
+
+def strip_overloads(gm):
+    """
+    Modifies the target of graph nodes in :attr:`gm` to strip overloads.
+
+    Args:
+        gm(fx.GraphModule): The input Fx graph module to be modified
+    """
+    for node in gm.graph.nodes:
+        if isinstance(node.target, torch._ops.OpOverload):
+            node.target = node.target.overloadpacket
+    gm.recompile()

functorch/_src/compilers.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
 from .aot_autograd import aot_function, aot_module
 from .decompositions import get_decompositions
 from .partitioners import draw_graph, min_cut_rematerialization_partition
+from .compile_utils import strip_overloads
 import time
 
 
@@ -20,19 +21,6 @@ def _canonicalize(fx_g):
     return fx_g
 
 
-def strip_overloads(gm):
-    """
-    Modifies the target of graph nodes in :attr:`gm` to strip overloads.
-
-    Args:
-        gm(fx.GraphModule): The input Fx graph module to be modified
-    """
-    for node in gm.graph.nodes:
-        if isinstance(node.target, torch._ops.OpOverload):
-            node.target = node.target.overloadpacket
-    gm.recompile()
-
-
 def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
     """
     Compiles the :attr:`fx_g` with Torchscript compiler.
@@ -245,6 +233,7 @@ def nop(fx_g: fx.GraphModule, _) -> Callable:
 
 
 def simple_ts_compile(fx_g, _):
+    strip_overloads(fx_g)
     f = torch.jit.script(fx_g)
     f = torch.jit.freeze(f.eval())
     return f

functorch/_src/partitioners.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
 import os
 from torch.fx.passes import graph_drawer
 from typing import Tuple
-from .cse import fx_graph_cse
+from .compile_utils import fx_graph_cse, strip_overloads
 
 
 class InvalidNodeBase(object):
@@ -228,6 +228,7 @@ def min_cut_rematerialization_partition(
        raise RuntimeError("Need networkx installed to perform smart recomputation heuristics")
 
    # add the CSE pass
+   strip_overloads(joint_module)
    fx_g = joint_module.graph
    cse_graph = fx_graph_cse(fx_g)
    joint_module.graph = cse_graph

test/test_memory_efficient_fusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
 from functorch import make_fx
 from torch.nn import functional as F
 from functorch.compile import memory_efficient_fusion
-from functorch._src.cse import fx_graph_cse
+from functorch._src.compile_utils import fx_graph_cse
 from torch.testing._internal.common_utils import TestCase, run_tests
 import inspect
 import random

test/test_pythonkey.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def f(a, b, c, d):
        def f(x):
            return torch.mm(x, torch.ones(x.shape)).tanh().tanh()
        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(5, 5, requires_grad=True)])
-       self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
+       self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
 
        ins, outs = get_ins_outs(fw_graph)
        self.assertEqual(outs[1].target, torch.ops.aten.mm)

0 commit comments

Comments (0)