1313from utils .utils import download , run
1414from abc import ABC , abstractmethod
1515from utils .unitrace import get_unitrace
16- from utils .logger import log
1716from utils .flamegraph import get_flamegraph
17+ from utils .logger import log
1818
1919
2020class TracingType (Enum ):
2121 """Enumeration of available tracing types."""
2222
23+ NONE = ""
2324 UNITRACE = "unitrace"
2425 FLAMEGRAPH = "flamegraph"
2526
@@ -88,15 +89,12 @@ def teardown(self):
8889 pass
8990
9091 @abstractmethod
91- def run (
92- self , env_vars , run_unitrace : bool = False , run_flamegraph : bool = False
93- ) -> list [Result ]:
92+ def run (self , env_vars , run_trace : TracingType = TracingType .NONE ) -> list [Result ]:
9493 """Execute the benchmark with the given environment variables.
9594
9695 Args:
9796 env_vars: Environment variables to use when running the benchmark.
98- run_unitrace: Whether to run benchmark under Unitrace.
99- run_flamegraph: Whether to run benchmark under FlameGraph.
97+ run_trace: The type of tracing to run (NONE, UNITRACE, or FLAMEGRAPH).
10098
10199 Returns:
102100 A list of Result objects with the benchmark results.
@@ -125,10 +123,8 @@ def run_bench(
125123 ld_library = [],
126124 add_sycl = True ,
127125 use_stdout = True ,
128- run_unitrace = False ,
129- extra_unitrace_opt = None ,
130- run_flamegraph = False ,
131- extra_perf_opt = None , # VERIFY
126+ run_trace : TracingType = TracingType .NONE ,
127+ extra_trace_opt = None ,
132128 ):
133129 env_vars = env_vars .copy ()
134130 if options .ur is not None :
@@ -141,11 +137,11 @@ def run_bench(
141137 ld_libraries = options .extra_ld_libraries .copy ()
142138 ld_libraries .extend (ld_library )
143139
144- if self .traceable (TracingType .UNITRACE ) and run_unitrace :
145- if extra_unitrace_opt is None :
146- extra_unitrace_opt = []
140+ if self .traceable (TracingType .UNITRACE ) and run_trace == TracingType . UNITRACE :
141+ if extra_trace_opt is None :
142+ extra_trace_opt = []
147143 unitrace_output , command = get_unitrace ().setup (
148- self .name (), command , extra_unitrace_opt
144+ self .name (), command , extra_trace_opt
149145 )
150146 log .debug (f"Unitrace output: { unitrace_output } " )
151147 log .debug (f"Unitrace command: { ' ' .join (command )} " )
@@ -159,24 +155,22 @@ def run_bench(
159155 ld_library = ld_libraries ,
160156 )
161157 except subprocess .CalledProcessError :
162- if run_unitrace :
158+ if run_trace == TracingType . UNITRACE :
163159 get_unitrace ().cleanup (options .benchmark_cwd , unitrace_output )
164160 raise
165161
166- if self .traceable (TracingType .UNITRACE ) and run_unitrace :
162+ if self .traceable (TracingType .UNITRACE ) and run_trace == TracingType . UNITRACE :
167163 get_unitrace ().handle_output (unitrace_output )
168164
169165 # flamegraph run
170166
171- ld_libraries = options .extra_ld_libraries .copy ()
172- ld_libraries .extend (ld_library )
173-
174167 perf_data_file = None
175- if self .traceable (TracingType .FLAMEGRAPH ) and run_flamegraph :
176- if extra_perf_opt is None :
177- extra_perf_opt = []
168+ if (
169+ self .traceable (TracingType .FLAMEGRAPH )
170+ and run_trace == TracingType .FLAMEGRAPH
171+ ):
178172 perf_data_file , command = get_flamegraph ().setup (
179- self .name (), command , extra_perf_opt
173+ self .name (), self . get_suite_name (), command
180174 )
181175 log .debug (f"FlameGraph perf data: { perf_data_file } " )
182176 log .debug (f"FlameGraph command: { ' ' .join (command )} " )
@@ -190,11 +184,15 @@ def run_bench(
190184 ld_library = ld_libraries ,
191185 )
192186 except subprocess .CalledProcessError :
193- if run_flamegraph and perf_data_file :
194- get_flamegraph ().cleanup (options . benchmark_cwd , perf_data_file )
187+ if run_trace == TracingType . FLAMEGRAPH and perf_data_file :
188+ get_flamegraph ().cleanup (perf_data_file )
195189 raise
196190
197- if self .traceable (TracingType .FLAMEGRAPH ) and run_flamegraph and perf_data_file :
191+ if (
192+ self .traceable (TracingType .FLAMEGRAPH )
193+ and run_trace == TracingType .FLAMEGRAPH
194+ and perf_data_file
195+ ):
198196 svg_file = get_flamegraph ().handle_output (
199197 self .name (), perf_data_file , self .get_suite_name ()
200198 )
0 commit comments