Skip to content

Commit 6896537

Browse files
authored
[trace] Enable single run of the benchmark function
Differential Revision: D83702432
Pull Request resolved: #503
1 parent d8b41f2 commit 6896537

File tree

2 files changed

+10
-28
lines changed

2 files changed

+10
-28
lines changed

tritonbench/components/ncu/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ def do_bench_in_task(
2121
grad_to_none=None,
2222
range_name: str = "",
2323
warmup: bool = False,
24-
warmup_time: int = 25,
24+
test_run: bool = False,
2525
use_cuda_profiler_range: bool = False,
2626
) -> None:
2727
"""
@@ -34,7 +34,8 @@ def do_bench_in_task(
3434
:type grad_to_none: torch.tensor, optional
3535
"""
3636

37-
fn()
37+
if test_run:
38+
fn()
3839
torch.cuda.synchronize()
3940

4041
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

tritonbench/utils/triton_op.py

Lines changed: 7 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1741,13 +1741,13 @@ def _init_extra_metrics() -> Dict[str, Any]:
17411741
self._latency_with_compile_in_task = metrics.extra_metrics[
17421742
"_compile_time_in_task"
17431743
]
1744-
if "_ncu_trace_in_task" in self.required_metrics:
1744+
if "single_run_in_task" in self.required_metrics:
17451745
assert (
1746-
self.required_metrics == ["_ncu_trace_in_task"]
1746+
self.required_metrics == ["single_run_in_task"]
17471747
and len(self._only) == 1
17481748
and (self._cur_input_id is not None)
17491749
), (
1750-
"_ncu_trace_in_task must be measured by itself. "
1750+
"single_run_in_task must be measured by itself. "
17511751
f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._cur_input_id}"
17521752
)
17531753
from tritonbench.components.ncu import do_bench_in_task
@@ -1757,26 +1757,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
17571757
grad_to_none=self.get_grad_to_none(self.example_inputs),
17581758
range_name=_RANGE_NAME,
17591759
)
1760-
metrics.extra_metrics["_ncu_trace_in_task"] = "success"
1761-
if "_nsys_rep_in_task" in self.required_metrics:
1762-
assert (
1763-
self.required_metrics == ["_nsys_rep_in_task"]
1764-
and len(self._only) == 1
1765-
and (self._cur_input_id is not None)
1766-
), (
1767-
"_nsys_rep_in_task must be measured by itself. "
1768-
f"required_metrics: {self.required_metrics}, _only: {self._only}, _input_id: {self._cur_input_id}"
1769-
)
1770-
from tritonbench.components.ncu import do_bench_in_task
1771-
1772-
do_bench_in_task(
1773-
fn=fn,
1774-
grad_to_none=self.get_grad_to_none(self.example_inputs),
1775-
range_name=_RANGE_NAME,
1776-
warmup=True,
1777-
use_cuda_profiler_range=True,
1778-
)
1779-
metrics.extra_metrics["_nsys_rep_in_task"] = "success"
1760+
metrics.extra_metrics["single_run_in_task"] = "success"
17801761
if self.tb_args.export:
17811762
export_data(
17821763
x_val=self.get_x_val(self.example_inputs),
@@ -1929,7 +1910,7 @@ def _get_op_task_args(
19291910
return op_task_args
19301911

19311912
def nsys_rep(self, input_id: int, fn_name: str) -> str:
1932-
op_task_args = self._get_op_task_args(input_id, fn_name, "_nsys_rep_in_task")
1913+
op_task_args = self._get_op_task_args(input_id, fn_name, "single_run_in_task")
19331914
nsys_output_dir = self.get_temp_path(fn_name)
19341915
nsys_output_dir.mkdir(parents=True, exist_ok=True)
19351916
ext = ".nsys-rep"
@@ -1971,7 +1952,7 @@ def ncu_trace(
19711952
"full",
19721953
]
19731954
)
1974-
op_task_args = self._get_op_task_args(input_id, fn_name, "_ncu_trace_in_task")
1955+
op_task_args = self._get_op_task_args(input_id, fn_name, "single_run_in_task")
19751956
# Disable DCGM
19761957
disable_dyno_dcgm = [
19771958
"sudo",
@@ -2052,7 +2033,7 @@ def service_exists(service_name):
20522033
return str(ncu_output_file.resolve())
20532034

20542035
def att_trace(self, input_id: int, fn_name: str) -> str:
2055-
op_task_args = self._get_op_task_args(input_id, fn_name, "_ncu_trace_in_task")
2036+
op_task_args = self._get_op_task_args(input_id, fn_name, "single_run_in_task")
20562037
att_output_dir = self.get_temp_path(fn_name)
20572038
att_trace_dir = launch_att(att_output_dir, op_task_args)
20582039
return att_trace_dir

0 commit comments

Comments
 (0)