1313VISUALIZATIONS_PATH = os .path .abspath (os .path .join (os .path .dirname (__file__ ), "visualizations/" ))
1414
1515
16+ # Map --sweep-mode values to the x_name used in benchmark CSV data.
17+ # "model_config" sweeps always write x_name="model_config"; token-length
18+ # sweeps use kernel-specific names (e.g. "T"), so we match them by exclusion.
19+ SWEEP_MODE_X_NAME = "model_config"
20+
21+
1622@dataclass
1723class VisualizationsConfig :
1824 """
@@ -22,6 +28,9 @@ class VisualizationsConfig:
2228 kernel_name (str): Kernel name to benchmark. (Will run `scripts/benchmark_{kernel_name}.py`)
2329 metric_name (str): Metric name to visualize (speed/memory)
2430 kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full"
31+ sweep_mode (str, optional): Sweep mode to filter data. "token_length" selects
32+ token/sequence-length sweep data; "model_config" selects model-configuration
33+ sweep data. When None, all data is considered (legacy behaviour).
2534 extra_config_filter (str, optional): A string to filter extra_benchmark_config.
2635 Can be a substring to match or a 'key=value' pair (e.g., "'H': 4096").
2736 Defaults to None, which means the first available config will be used if multiple exist.
@@ -33,6 +42,7 @@ class VisualizationsConfig:
3342 kernel_name : str
3443 metric_name : str
3544 kernel_operation_mode : str = "full"
45+ sweep_mode : str = "token_length"
3646 extra_config_filter : str | None = None
3747 display : bool = False
3848 overwrite : bool = False
@@ -59,6 +69,15 @@ def parse_args() -> VisualizationsConfig:
5969 default = None ,
6070 help = "Kernel operation modes to visualize (forward/backward/full). If not provided, generate for all available modes." ,
6171 )
72+ parser .add_argument (
73+ "--sweep-mode" ,
74+ type = str ,
75+ choices = ["token_length" , "model_config" ],
76+ default = "token_length" ,
77+ help = "Sweep mode used when running the benchmark. "
78+ "'token_length' selects token/sequence-length sweep data (default); "
79+ "'model_config' selects model-configuration sweep data." ,
80+ )
6281 parser .add_argument (
6382 "--extra-config-filter" ,
6483 type = str ,
@@ -93,11 +112,19 @@ def load_data(config: VisualizationsConfig) -> pd.DataFrame:
93112 df = pd .read_csv (DATA_PATH )
94113 df ["extra_benchmark_config" ] = df ["extra_benchmark_config_str" ].apply (json .loads )
95114
96- base_filtered_df = df [
115+ mask = (
97116 (df ["kernel_name" ] == config .kernel_name )
98117 & (df ["metric_name" ] == config .metric_name )
99118 & (df ["kernel_operation_mode" ] == config .kernel_operation_mode )
100- ]
119+ )
120+
121+ # Filter by sweep mode early, before extra_benchmark_config resolution.
122+ if config .sweep_mode == "model_config" :
123+ mask = mask & (df ["x_name" ] == SWEEP_MODE_X_NAME )
124+ elif config .sweep_mode == "token_length" :
125+ mask = mask & (df ["x_name" ] != SWEEP_MODE_X_NAME )
126+
127+ base_filtered_df = df [mask ]
101128
102129 if base_filtered_df .empty :
103130 raise ValueError (
@@ -201,6 +228,14 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
201228 if col in df .columns :
202229 df [col ] = pd .to_numeric (df [col ], errors = "coerce" )
203230
231+ # Convert x_value to numeric where possible so matplotlib uses a real
232+ # numeric axis (proper proportional spacing). String x_values (e.g.
233+ # model names) stay as-is and will be treated as categorical (evenly spaced).
234+ x_numeric = pd .to_numeric (df ["x_value" ], errors = "coerce" )
235+ is_numeric_x = x_numeric .notna ().all ()
236+ if is_numeric_x :
237+ df ["x_value" ] = x_numeric
238+
204239 xlabel = df ["x_label" ].iloc [0 ]
205240 ylabel = f"{ config .metric_name } ({ df ['metric_unit' ].iloc [0 ]} )"
206241 # Sort by "kernel_provider" to ensure consistent color assignment
@@ -229,12 +264,17 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
229264 errorbar = None ,
230265 )
231266
267+ # For numeric x axes, show tick labels only at actual data points
268+ if is_numeric_x :
269+ tick_values = sorted (df ["x_value" ].unique ())
270+ ax .set_xticks (tick_values )
271+ ax .set_xticklabels ([str (int (v )) if v == int (v ) else str (v ) for v in tick_values ])
272+
232273 # Seaborn can't plot pre-computed error bars, so we need to do it manually
233274 lines = ax .get_lines ()
234275 colors = [line .get_color () for line in lines ]
235276
236277 for (_ , group_data ), color in zip (df .groupby ("kernel_provider" ), colors ):
237- # for i, row in group_data.iterrows():
238278 y_error_lower = group_data ["y_value_50" ] - group_data ["y_value_20" ]
239279 y_error_upper = group_data ["y_value_80" ] - group_data ["y_value_50" ]
240280 y_error = [y_error_lower , y_error_upper ]
@@ -252,9 +292,10 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
252292 plt .ylabel (ylabel )
253293 plt .tight_layout ()
254294
295+ sweep_suffix = f"_{ config .sweep_mode } " if config .sweep_mode else ""
255296 out_path = os .path .join (
256297 VISUALIZATIONS_PATH ,
257- f"{ config .kernel_name } _{ config .metric_name } _{ config .kernel_operation_mode } .png" ,
298+ f"{ config .kernel_name } _{ config .metric_name } _{ config .kernel_operation_mode } { sweep_suffix } .png" ,
258299 )
259300
260301 if config .display :
@@ -288,6 +329,7 @@ def main():
288329 kernel_name = args .kernel_name ,
289330 metric_name = args .metric_name ,
290331 kernel_operation_mode = mode ,
332+ sweep_mode = args .sweep_mode ,
291333 extra_config_filter = args .extra_config_filter ,
292334 display = args .display ,
293335 overwrite = args .overwrite ,
0 commit comments