
Commit 8adba6e

kiya00 and Copilot authored
Improve reporting tools (#2787)
Co-authored-by: Copilot <[email protected]>
1 parent 9c1ac8e commit 8adba6e

File tree

3 files changed: +126 −29 lines changed


thunder/dynamo/benchmark_utils.py

Lines changed: 7 additions & 1 deletion

@@ -150,7 +150,7 @@ def compile(self, fn, *, inputs, **kwargs):
 
     # to_source will always use symbolic trace
     def to_source(self, fn_name):
-        return f"TorchInductorSpecification.torch_inductor({fn_name}, inputs)"
+        return f"TorchInductorSpecification.torch_inductor({fn_name}, inputs, skip_symbolic_trace={self.skip_symbolic_trace})"
 
     def import_str(self):
         return ["import torch", "from thunder.dynamo.benchmark_utils import TorchInductorSpecification"]

@@ -353,6 +353,12 @@ def time(self, stmt="pass", setup="pass", globals=None) -> Measurement:
             Measurement: A benchmarking result containing execution time statistics, see :class:`torch.utils.benchmark.utils.common.Measurement`.
         """
         t = TorchBenchmarkTimer(stmt=stmt, setup=setup, globals=globals, timer=self.inner_timer)
+        # If the timer measures an extremely short execution time, adaptive_autorange may hang.
+        # To prevent this, we perform a preliminary run to check for such cases, e.g. measuring kernel time on a CPU-only graph.
+        # If detected, we return the time of a single run, avoiding potential hangs.
+        pre_run = t.timeit(1)
+        if pre_run.median <= 1e-9:
+            return pre_run
         measurement = t.adaptive_autorange(
             threshold=self.threshold, min_run_time=self.min_run_time, max_run_time=self.max_run_time
         )
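For context on the guard above, here is a minimal standalone sketch of the same pattern written against the stock torch.utils.benchmark.Timer (the commit applies it inside Thunder's TorchBenchmarkTimer wrapper; the measure helper and the threshold/run-time values below are illustrative assumptions, not the wrapper's actual defaults):

# A minimal sketch, assuming only the public torch.utils.benchmark API.
import torch
from torch.utils.benchmark import Timer


def measure(stmt: str, env: dict):
    t = Timer(stmt=stmt, globals=env)
    # Preliminary single run: a near-zero median (e.g. a kernel-time timer on a
    # CPU-only graph that launches no GPU work) can make adaptive_autorange loop
    # for a very long time, so return the single-run measurement instead.
    pre_run = t.timeit(1)
    if pre_run.median <= 1e-9:
        return pre_run
    # Otherwise fall through to the adaptive measurement loop.
    return t.adaptive_autorange(threshold=0.1, min_run_time=0.01, max_run_time=10.0)


print(measure("x + x", {"x": torch.ones(4)}))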

thunder/dynamo/report.py

Lines changed: 57 additions & 23 deletions

@@ -520,7 +520,7 @@ def write_repro(
         code_str = f"{code_str}\n{main_code.format(graph_name=self.graph_name)}\n{comment_str}"
 
         if file_name is None:
-            file_name = f"{self.graph_name}.py"
+            file_name = f"{self.graph_name}_{compile_fn.name}_repro.py"
         with open(folder / file_name, "w") as f:
             print(code_str, file=f)
         format_python_file(folder / file_name)

@@ -633,7 +633,7 @@ def write_benchmark(
 
         code_str = f"{code_str}\n{main_code.format(graph_name=self.graph_name)}\n{comment_str}"
         if file_name is None:
-            file_name = f"{self.graph_name}.py"
+            file_name = f"{self.graph_name}_{compile_fn.name}_{time_fn.name}_benchmark.py"
         with open(folder / file_name, "w") as f:
             print(code_str, file=f)
         format_python_file(folder / file_name)

@@ -924,7 +924,7 @@ def write_nvfuser_benchmark(self, folder, time_fn: TimerInterface, file_name=None):
 {comment_str}
 """
         if file_name is None:
-            file_name = f"{self.name}_benchmark_nvfuser.py"
+            file_name = f"{self.name}_benchmark_nvfuser_{time_fn.name}.py"
         with open(folder / file_name, "w") as f:
             print(code_str, file=f)
         format_python_file(folder / file_name)

@@ -983,7 +983,7 @@ def write_inductor_benchmark(self, folder: PathLike, time_fn: TimerInterface, file_name=None):
     print(measurement)
 """
         if file_name is None:
-            file_name = f"{self.name}_benchmark_inductor.py"
+            file_name = f"{self.name}_benchmark_inductor_{time_fn.name}.py"
         with open(folder / file_name, "w") as f:
             f.write(code_str)
         format_python_file(folder / file_name)

@@ -1428,22 +1428,39 @@ def save_thunderfx_repros(
     Saves reproduction scripts for ThunderFX subgraphs.
 
     This function:
-    1. Creates a folder structure to organize the repros
-       .
-       └── graph0
-           ├── fusion_reports
-           │   ├── graph0_thunder_0_nvFusion0_forward_repro_nvfuser.py
-           │   ├── graph0_thunder_0_nvFusion1_forward_repro_nvfuser.py
-           │   ├── graph0_thunder_0_nvFusion2_backward_repro_nvfuser.py
-           ├── graph0_thunder_0_bwd_trace.py
-           ├── graph0_thunder_0_fwd_trace.py
-           └── graph0_thunder_0.py
+    1. Creates a folder structure to organize the repro or benchmark scripts:
+
+       If use_benchmark is True:
+           graph0/
+           ├── fusion_reports/
+           │   ├── graph0_thunder_0_nvFusion0_forward_benchmark_inductor_KernelTime.py
+           │   ├── graph0_thunder_0_nvFusion0_forward_benchmark_inductor_WallTimeWithMemoryUsage.py
+           │   ├── graph0_thunder_0_nvFusion0_forward_benchmark_nvfuser_KernelTime.py
+           │   └── graph0_thunder_0_nvFusion0_forward_benchmark_nvfuser_WallTimeWithMemoryUsage.py
+           ├── graph0_repro_torchcompile.py
+           ├── graph0_thunder_0_bwd_trace.py
+           ├── graph0_thunder_0_fwd_trace.py
+           ├── graph0_thunder_0_inductor_KernelTime_benchmark.py
+           ├── graph0_thunder_0_inductor_WallTimeWithMemoryUsage_benchmark.py
+           ├── graph0_thunder_0_thunder_KernelTime_benchmark.py
+           └── graph0_thunder_0_thunder_WallTimeWithMemoryUsage_benchmark.py
+
+       If use_benchmark is False:
+           graph0/
+           ├── fusion_reports/
+           │   ├── graph0_thunder_0_nvFusion0_forward_repro_inductor.py
+           │   └── graph0_thunder_0_nvFusion0_forward_repro_nvfuser.py
+           ├── graph0_repro_torchcompile.py
+           ├── graph0_thunder_0_fwd_trace.py
+           ├── graph0_thunder_0_bwd_trace.py
+           ├── graph0_thunder_0_inductor_repro.py
+           └── graph0_thunder_0_thunder_repro.py
 
     2. For each Thunder FX graph and its subgraphs:
-    - Checks runnability if requested
-    - Saves benchmark or repro scripts
-    - Saves trace information if requested
-    - Saves nvFusion repros if requested
+       - Checks runnability if requested
+       - Saves benchmark or repro scripts
+       - Saves trace information if requested
+       - Saves nvFusion repros if requested
 
     Args:
         fn: The callable to analyze

@@ -1452,7 +1469,7 @@ def save_thunderfx_repros(
         check_runnability: If True, checks if graphs can run with Thunder
         save_fusion: If True, saves nvFusion repros
         save_trace: If True, saves trace information
-        stream: Stream to write output log informationto
+        stream: Stream to write output log information to
         force_overwrite: If True, overwrites existing folder at folder_path
         **compile_kwargs: Keyword arguments for Thunder and torch.compile

@@ -1472,6 +1489,7 @@ def inner_fn(*args, **kwargs):
     for thunder_fxgraph_report in thunder_fxgraph_reports:
         graph_folder = folder_path / thunder_fxgraph_report.graph_name
         graph_folder.mkdir(exist_ok=True, parents=True)
+        thunder_fxgraph_report.write_inductor_repro(graph_folder)
         for split_report in thunder_fxgraph_report.subgraph_reports:
             if check_runnability or save_trace or save_fusion:
                 try:

@@ -1484,22 +1502,38 @@ def inner_fn(*args, **kwargs):
                     continue
                 else:
                     stream.write(f"Successfully ran the {split_report.graph_name} using Thunder\n")
+
+            from torch._inductor.compile_fx import graph_returns_tuple
+
+            # torch._inductor.compile requires the output to be a tuple; if it is not, the symbolic trace is necessary
+            skip_symbolic_trace = graph_returns_tuple(split_report.graph)
+            torchinductor = TorchInductorSpecification(skip_symbolic_trace=skip_symbolic_trace)
             if use_benchmark:
-                split_report.write_benchmark(graph_folder, thunderjit, WallTime)
+                split_report.write_benchmark(graph_folder, thunderjit, WallTimeWithMemoryUsage)
+                split_report.write_benchmark(graph_folder, thunderjit, KernelTime)
+
+                split_report.write_benchmark(graph_folder, torchinductor, WallTimeWithMemoryUsage)
+                split_report.write_benchmark(graph_folder, torchinductor, KernelTime)
             else:
                 split_report.write_repro(graph_folder, thunderjit)
+                split_report.write_repro(graph_folder, torchinductor)
             if save_trace:
                 with open(graph_folder / f"{split_report.graph_name}_fwd_trace.py", "w") as f:
                     f.write(str(split_report.fwd_trc))
-                with open(graph_folder / f"{split_report.graph_name}_bwd_trace.py", "w") as f:
-                    f.write(str(split_report.bwd_trc))
+                if split_report.bwd_trc is not None:
+                    with open(graph_folder / f"{split_report.graph_name}_bwd_trace.py", "w") as f:
+                        f.write(str(split_report.bwd_trc))
             if save_fusion:
                 fusion_folder = graph_folder / "fusion_reports"
                 fusion_folder.mkdir(exist_ok=True, parents=True)
                 for fusion_report in split_report.fusion_reports:
                     if use_benchmark:
-                        fusion_report.write_nvfuser_benchmark(fusion_folder, WallTime)
+                        fusion_report.write_nvfuser_benchmark(fusion_folder, WallTimeWithMemoryUsage)
+                        fusion_report.write_inductor_benchmark(fusion_folder, WallTimeWithMemoryUsage)
+                        fusion_report.write_nvfuser_benchmark(fusion_folder, KernelTime)
+                        fusion_report.write_inductor_benchmark(fusion_folder, KernelTime)
                     else:
                         fusion_report.write_nvfuser_repro(fusion_folder)
+                        fusion_report.write_inductor_repro(fusion_folder)
 
     return inner_fn
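To see what the graph_returns_tuple check distinguishes, here is a small sketch (single and tupled are hypothetical example functions; graph_returns_tuple is the same private torch._inductor.compile_fx helper the commit imports, so its location may change between torch versions):

# Sketch: torch._inductor.compile expects the GraphModule to return a tuple,
# so the extra symbolic trace is only needed when the graph does not.
import torch
from torch._inductor.compile_fx import graph_returns_tuple


def single(x):
    return x.cos()  # single-tensor output


def tupled(x):
    return (x.cos(),)  # tuple output


print(graph_returns_tuple(torch.fx.symbolic_trace(single)))  # False -> re-trace needed
print(graph_returns_tuple(torch.fx.symbolic_trace(tupled)))  # True  -> trace can be skipped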

thunder/tests/test_dynamo.py

Lines changed: 62 additions & 5 deletions

@@ -1575,7 +1575,7 @@ def foo(x):
     assert len(thunder_fx_graph_report.subgraph_reports) == 1  # cos
     thunder_split_report = thunder_fx_graph_report.subgraph_reports[0]
 
-    torchinductor = TorchInductorSpecification()
+    torchinductor = TorchInductorSpecification(skip_symbolic_trace=False)
     thunder_split_report.run_benchmark(torchinductor, WallTime)
     thunder_split_report.run_repro(torchinductor)
     thunder_split_report.write_benchmark(tmp_path, torchinductor, WallTime)

@@ -1601,14 +1601,14 @@ def foo(x):
     results = fx_report(foo)(x)
     with patch("thunder.dynamo.report.FXGraphReport.run_repro", side_effect=Exception("run_Repro raises exception")):
         save_failing_repros(results.fx_graph_reports, TorchCompileSpecification(), tmp_path)
-    assert os.path.exists(tmp_path / "graph0.py")
+    assert os.path.exists(tmp_path / "graph0_torchcompile_repro.py")
 
     # Tests for thunder split reports
     thunder_fxgraph_reports = get_thunder_fxgraph_reports(foo)(x)
     assert len(thunder_fxgraph_reports) == 1
     with patch("thunder.dynamo.report.FXGraphReport.run_repro", side_effect=Exception("run_Repro raises exception")):
         save_failing_repros(thunder_fxgraph_reports[0].subgraph_reports, ThunderCompileSpecification(), tmp_path)
-    assert os.path.exists(tmp_path / "graph0_thunder_0.py")
+    assert os.path.exists(tmp_path / "graph0_thunder_0_thunder_repro.py")
 
     # Tests for check_consistency
     def wrapped_fn(x):

@@ -1622,12 +1622,12 @@ def compile(self, fn, **kwargs):
     save_failing_repros(
         results.fx_graph_reports, _BadCompileSpecification(), tmp_path / "consistency", check_consistency=False
     )
-    assert not os.path.exists(tmp_path / "consistency" / "graph0.py")
+    assert not os.path.exists(tmp_path / "consistency" / "graph0_torcheager_repro.py")
 
     save_failing_repros(
         results.fx_graph_reports, _BadCompileSpecification(), tmp_path / "consistency", check_consistency=True
     )
-    assert os.path.exists(tmp_path / "consistency" / "graph0.py")
+    assert os.path.exists(tmp_path / "consistency" / "graph0_torcheager_repro.py")
 
 
 @requiresCUDA

@@ -1935,3 +1935,60 @@ def fn():
         "is a `torch.cuda.Stream` method which is not supported by Thunder" in getattr(reason, "info", "")
         for reason in split_reasons
     )
+
+
+@requiresCUDA
+@pytest.mark.parametrize("use_benchmark", (True, False), ids=("benchmark", "repro"))
+def test_save_thunderfx_repros(use_benchmark, tmp_path):
+    from thunder.dynamo.report import save_thunderfx_repros
+
+    x = torch.ones(2, 2, device="cuda", requires_grad=False)
+
+    def foo(x):
+        # torch.sinc has an automatic fallback registered,
+        # so that operation will be given to Inductor.
+        x = x.exp()
+        torch._dynamo.graph_break()
+        return torch.sinc(x) + torch.cos(x)
+
+    save_thunderfx_repros(
+        foo,
+        tmp_path,
+        use_benchmark=use_benchmark,
+        check_runnability=True,
+        save_fusion=True,
+        save_trace=True,
+        force_overwrite=True,
+        disable_torch_autograd=True,
+    )(x)
+
+    # Checks that the scripts are generated correctly
+    subdirs = [d for d in os.listdir(tmp_path) if os.path.isdir(os.path.join(tmp_path, d))]
+    for d in subdirs:
+        assert d.startswith("graph"), f"{d} is not a graph folder"
+    assert len(subdirs) == 2, f"there should be 2 graphs, but in fact {subdirs}"
+
+    num_backend = 2
+    num_traces = 1  # forward only
+    if use_benchmark:
+        num_g_files = num_backend * 2 + num_traces + 1
+        num_fusion_files = num_backend * 2
+    else:
+        num_g_files = num_backend + num_traces + 1
+        num_fusion_files = num_backend
+
+    for d in subdirs:
+        graph_dir = os.path.join(tmp_path, d)
+        # Checks that the fusion_reports directory exists
+        fusion_reports_dir = os.path.join(graph_dir, "fusion_reports")
+        assert os.path.isdir(fusion_reports_dir), f"{fusion_reports_dir} doesn't exist"
+        # Checks that the graph directory has the correct number of files
+        g_files = [f for f in os.listdir(graph_dir) if os.path.isfile(os.path.join(graph_dir, f))]
+        assert len(g_files) == num_g_files, f"{graph_dir} should have {num_g_files} files, but in fact {g_files}"
+        # Checks that the fusion_reports directory has the correct number of files
+        fusion_files = [
+            f for f in os.listdir(fusion_reports_dir) if os.path.isfile(os.path.join(fusion_reports_dir, f))
+        ]
+        assert len(fusion_files) == num_fusion_files, (
+            f"{fusion_reports_dir} should have {num_fusion_files} files, but in fact {fusion_files}"
+        )
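Reading the expected counts in the test above: with num_backend = 2 (Thunder and Inductor), num_traces = 1 (forward only, since the input has requires_grad=False), and one whole-graph graph0_repro_torchcompile.py script per graph folder, benchmark mode expects 2 * 2 + 1 + 1 = 6 files in each graph folder and 2 * 2 = 4 scripts in fusion_reports, while repro mode expects 2 + 1 + 1 = 4 and 2 respectively.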
