Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit ae404ce

Browse files
committed
Added subprocess launcher for NVFuser minifying
1 parent 97822c1 commit ae404ce

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

functorch/_src/fx_minifier.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def minimizer(fail_f: fx.GraphModule, inps, module_fails):
5959
cur_size = len(failing_graph.nodes)
6060

6161
def graph_fails(graph, inps):
62+
6263
mod = fx.GraphModule(fail_f, graph)
6364
mod.graph.lint()
6465
return module_fails(mod, inps)
@@ -190,4 +191,26 @@ def delta_debugging(cur_graph: fx.Graph, cur_inps):
190191
failing_fx = fx.GraphModule(fail_f, failing_graph)
191192
print(failing_fx.code)
192193
print([i.shape for i in inps])
193-
return failing_fx, inps
194+
return failing_fx, inps
195+
196+
import subprocess
197+
def check_nvfuser_subprocess(f, inps):
198+
f.to_folder("temp")
199+
with open("_temp.py", 'w') as fil:
200+
fil.write(f'''
201+
import torch
202+
from temp import FxModule
203+
f = FxModule().cuda()
204+
inps = {[(i.shape, i.dtype) for i in inps]}
205+
inps = [torch.randn(shape, dtype=dtype, device='cuda') for shape, dtype in inps]
206+
with torch.jit.fuser("fuser2"):
207+
nf = torch.jit.script(f)
208+
for _ in range(5):
209+
nf(*inps)
210+
''')
211+
try:
212+
subprocess.check_call("PYTORCH_NVFUSER_DISABLE_FALLBACK=1 python _temp.py", shell=True)
213+
except Exception as e:
214+
print(e)
215+
return True
216+
return False

test/test_minifier.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def pass_checker(fx_g, inps):
4444
return torch.isnan(fx_g(*inps)[0]).any()
4545

4646
min_f, inps = minimizer(failing_f, inps, pass_checker)
47-
import pdb; pdb.set_trace()
4847
assert len(min_f.graph.nodes) == 3
4948
assert len(inps) == 1
5049

0 commit comments

Comments
 (0)