@@ -89,12 +89,18 @@ def teardown(self):
8989 pass
9090
9191 @abstractmethod
92- def run (self , env_vars , run_trace : TracingType = TracingType .NONE ) -> list [Result ]:
92+ def run (
93+ self ,
94+ env_vars ,
95+ run_trace : TracingType = TracingType .NONE ,
96+ force_trace : bool = False ,
97+ ) -> list [Result ]:
9398 """Execute the benchmark with the given environment variables.
9499
95100 Args:
96101 env_vars: Environment variables to use when running the benchmark.
97102 run_trace: The type of tracing to run (NONE, UNITRACE, or FLAMEGRAPH).
103+ force_trace: If True, ignore the traceable() method and force tracing.
98104
99105 Returns:
100106 A list of Result objects with the benchmark results.
@@ -125,6 +131,7 @@ def run_bench(
125131 use_stdout = True ,
126132 run_trace : TracingType = TracingType .NONE ,
127133 extra_trace_opt = None ,
134+ force_trace : bool = False ,
128135 ):
129136 env_vars = env_vars .copy ()
130137 if options .ur is not None :
@@ -137,7 +144,10 @@ def run_bench(
137144 ld_libraries = options .extra_ld_libraries .copy ()
138145 ld_libraries .extend (ld_library )
139146
140- if self .traceable (TracingType .UNITRACE ) and run_trace == TracingType .UNITRACE :
147+ unitrace_output = None
148+ if (
149+ self .traceable (TracingType .UNITRACE ) or force_trace
150+ ) and run_trace == TracingType .UNITRACE :
141151 if extra_trace_opt is None :
142152 extra_trace_opt = []
143153 unitrace_output , command = get_unitrace ().setup (
@@ -146,29 +156,12 @@ def run_bench(
146156 log .debug (f"Unitrace output: { unitrace_output } " )
147157 log .debug (f"Unitrace command: { ' ' .join (command )} " )
148158
149- try :
150- result = run (
151- command = command ,
152- env_vars = env_vars ,
153- add_sycl = add_sycl ,
154- cwd = options .benchmark_cwd ,
155- ld_library = ld_libraries ,
156- )
157- except subprocess .CalledProcessError :
158- if run_trace == TracingType .UNITRACE :
159- get_unitrace ().cleanup (options .benchmark_cwd , unitrace_output )
160- raise
161-
162- if self .traceable (TracingType .UNITRACE ) and run_trace == TracingType .UNITRACE :
163- get_unitrace ().handle_output (unitrace_output )
164-
165159 # flamegraph run
166160
167161 perf_data_file = None
168162 if (
169- self .traceable (TracingType .FLAMEGRAPH )
170- and run_trace == TracingType .FLAMEGRAPH
171- ):
163+ self .traceable (TracingType .FLAMEGRAPH ) or force_trace
164+ ) and run_trace == TracingType .FLAMEGRAPH :
172165 perf_data_file , command = get_flamegraph ().setup (
173166 self .name (), self .get_suite_name (), command
174167 )
@@ -184,12 +177,21 @@ def run_bench(
184177 ld_library = ld_libraries ,
185178 )
186179 except subprocess .CalledProcessError :
180+ if run_trace == TracingType .UNITRACE and unitrace_output :
181+ get_unitrace ().cleanup (options .benchmark_cwd , unitrace_output )
187182 if run_trace == TracingType .FLAMEGRAPH and perf_data_file :
188183 get_flamegraph ().cleanup (perf_data_file )
189184 raise
190185
191186 if (
192- self .traceable (TracingType .FLAMEGRAPH )
187+ (self .traceable (TracingType .UNITRACE ) or force_trace )
188+ and run_trace == TracingType .UNITRACE
189+ and unitrace_output
190+ ):
191+ get_unitrace ().handle_output (unitrace_output )
192+
193+ if (
194+ (self .traceable (TracingType .FLAMEGRAPH ) or force_trace )
193195 and run_trace == TracingType .FLAMEGRAPH
194196 and perf_data_file
195197 ):
0 commit comments