Skip to content

Commit daecb64

Browse files
authored
Print failing input id when error; print failing backend name (#510)
1 parent 6896537 commit daecb64

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

tritonbench/utils/triton_op.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,7 @@ class BenchmarkOperator(metaclass=PostInitProcessor):
684684
DEFAULT_METRICS = ["latency"]
685685
required_metrics: List[str]
686686
_cur_input_id: Optional[int] = None
687+
_cur_backend_name: Optional[str] = None
687688
_input_iter: Optional[Generator] = None
688689
extra_args: List[str] = []
689690
example_inputs: Any = None
@@ -980,6 +981,7 @@ def run(
980981

981982
current_pos = 0
982983
for input_id in input_id_range:
984+
self._cur_backend_name = None
983985
# Skip to the correct position if there are gaps
984986
while current_pos < input_id:
985987
self.example_inputs = self.get_example_inputs()
@@ -1078,6 +1080,7 @@ def run(
10781080

10791081
# get metrics for for each registered benchmark
10801082
def _reduce_benchmarks(acc, bm_name: str):
1083+
self._cur_backend_name = bm_name
10811084
baseline = (
10821085
bm_name == BASELINE_BENCHMARKS[self.name]
10831086
if self.name in BASELINE_BENCHMARKS
@@ -1101,6 +1104,7 @@ def _reduce_benchmarks(acc, bm_name: str):
11011104
y_vals: Dict[str, BenchmarkOperatorMetrics] = functools.reduce(
11021105
_reduce_benchmarks, benchmarks, {}
11031106
)
1107+
self._cur_backend_name = None
11041108
metrics.append((x_val, y_vals))
11051109
del self.example_inputs # save some memory
11061110
if "proton" in self.required_metrics:
@@ -1112,10 +1116,21 @@ def _reduce_benchmarks(acc, bm_name: str):
11121116
proton.exit_scope()
11131117
proton.finalize()
11141118
except (KeyboardInterrupt, Exception):
1119+
backend_suffix = (
1120+
f" on backend {self._cur_backend_name}"
1121+
if self._cur_backend_name is not None
1122+
else ""
1123+
)
11151124
logger.warning(
1116-
"Caught exception, terminating early with partial results",
1125+
"Caught exception%s, terminating early with partial results",
1126+
backend_suffix,
11171127
exc_info=True,
11181128
)
1129+
if getattr(self, "_cur_input_id", None) is not None:
1130+
logger.warning(
1131+
"Failing input: --input-id %s --num-inputs 1",
1132+
self._cur_input_id,
1133+
)
11191134
if self.tb_args.exit_on_exception:
11201135
os._exit(1)
11211136
raise

0 commit comments

Comments
 (0)