1212from options import options
1313from utils .utils import download , run
1414from abc import ABC , abstractmethod
15+ from utils .flamegraph import get_flamegraph
16+ from utils .logger import log
1517
1618benchmark_tags = [
1719 BenchmarkTag ("SYCL" , "Benchmark uses SYCL runtime" ),
@@ -61,6 +63,12 @@ def enabled(self) -> bool:
6163 By default, it returns True, but can be overridden to disable a benchmark."""
6264 return True
6365
66+ def traceable (self ) -> bool :
67+ """Returns whether this benchmark should be traced by FlameGraph.
68+ By default, it returns True, but can be overridden to disable tracing for a benchmark.
69+ """
70+ return True
71+
6472 @abstractmethod
6573 def setup (self ):
6674 pass
@@ -70,11 +78,12 @@ def teardown(self):
7078 pass
7179
7280 @abstractmethod
73- def run (self , env_vars ) -> list [Result ]:
81+ def run (self , env_vars , run_flamegraph : bool = False ) -> list [Result ]:
7482 """Execute the benchmark with the given environment variables.
7583
7684 Args:
7785 env_vars: Environment variables to use when running the benchmark.
86+ run_flamegraph: Whether to run benchmark under FlameGraph.
7887
7988 Returns:
8089 A list of Result objects with the benchmark results.
@@ -97,7 +106,14 @@ def get_adapter_full_path():
97106 ), f"could not find adapter file { adapter_path } (and in similar lib paths)"
98107
99108 def run_bench (
100- self , command , env_vars , ld_library = [], add_sycl = True , use_stdout = True
109+ self ,
110+ command ,
111+ env_vars ,
112+ ld_library = [],
113+ add_sycl = True ,
114+ use_stdout = True ,
115+ run_flamegraph = False ,
116+ extra_perf_opt = None ,
101117 ):
102118 env_vars = env_vars .copy ()
103119 if options .ur is not None :
@@ -110,13 +126,32 @@ def run_bench(
110126 ld_libraries = options .extra_ld_libraries .copy ()
111127 ld_libraries .extend (ld_library )
112128
113- result = run (
114- command = command ,
115- env_vars = env_vars ,
116- add_sycl = add_sycl ,
117- cwd = options .benchmark_cwd ,
118- ld_library = ld_libraries ,
119- )
129+ perf_data_file = None
130+ if self .traceable () and run_flamegraph :
131+ if extra_perf_opt is None :
132+ extra_perf_opt = []
133+ perf_data_file , command = get_flamegraph ().setup (
134+ self .name (), command , extra_perf_opt
135+ )
136+ log .debug (f"FlameGraph perf data: { perf_data_file } " )
137+ log .debug (f"FlameGraph command: { ' ' .join (command )} " )
138+
139+ try :
140+ result = run (
141+ command = command ,
142+ env_vars = env_vars ,
143+ add_sycl = add_sycl ,
144+ cwd = options .benchmark_cwd ,
145+ ld_library = ld_libraries ,
146+ )
147+ except subprocess .CalledProcessError :
148+ if run_flamegraph and perf_data_file :
149+ get_flamegraph ().cleanup (options .benchmark_cwd , perf_data_file )
150+ raise
151+
152+ if self .traceable () and run_flamegraph and perf_data_file :
153+ svg_file = get_flamegraph ().handle_output (self .name (), perf_data_file )
154+ log .info (f"FlameGraph generated: { svg_file } " )
120155
121156 if use_stdout :
122157 return result .stdout .decode ()
0 commit comments