Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 7b70939

Browse files
authored
add a pass in ts_compile to prepare for jit.script (#899)
* add a pass in ts_compile to strip overloads and prepare for jit.script
1 parent d9de359 commit 7b70939

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

functorch/_src/compilers.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,19 @@ def _canonicalize(fx_g):
2020
return fx_g
2121

2222

23+
def strip_overloads(gm):
24+
"""
25+
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
26+
27+
Args:
28+
gm(fx.GraphModule): The input Fx graph module to be modified
29+
"""
30+
for node in gm.graph.nodes:
31+
if isinstance(node.target, torch._ops.OpOverload):
32+
node.target = node.target.overloadpacket
33+
gm.recompile()
34+
35+
2336
def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
2437
"""
2538
Compiles the :attr:`fx_g` with Torchscript compiler.
@@ -46,6 +59,8 @@ def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
4659
new_kwargs[k] = v
4760
node.kwargs = new_kwargs
4861

62+
strip_overloads(fx_g)
63+
4964
fx_g.graph.lint()
5065

5166
fx_g.recompile()

0 commit comments

Comments
 (0)