@@ -107,26 +107,113 @@ def update_args_with_global(base_args: argparse.Namespace, global_args: List[str
     return updated_args
 
 
+def _analyze_config_differences(config_a_args: List[str], config_b_args: List[str]) -> Dict[str, Tuple[str, str]]:
+    """Analyze differences between two configurations."""
+    # Parse arguments into dictionaries
+    def parse_config_to_dict(args):
+        config_dict = {}
+        i = 0
+        while i < len(args):
+            if args[i].startswith('--'):
+                key = args[i][2:]  # Remove --
+                if '=' in args[i]:
+                    # Format: --key=value
+                    key, value = args[i][2:].split('=', 1)
+                    config_dict[key] = value
+                    i += 1
+                elif i + 1 < len(args) and not args[i + 1].startswith('-'):
+                    # Format: --key value
+                    config_dict[key] = args[i + 1]
+                    i += 2
+                else:
+                    # Flag without value
+                    config_dict[key] = "True"
+                    i += 1
+            else:
+                i += 1
+        return config_dict
+
+    config_a = parse_config_to_dict(config_a_args)
+    config_b = parse_config_to_dict(config_b_args)
+
+    # Find differences
+    differences = {}
+    all_keys = set(config_a.keys()) | set(config_b.keys())
+
+    for key in all_keys:
+        val_a = config_a.get(key, "default")
+        val_b = config_b.get(key, "default")
+        if val_a != val_b:
+            differences[key] = (val_a, val_b)
+
+    return differences
+
+
+def _calculate_performance_summary(result_a: BenchmarkOperatorResult, result_b: BenchmarkOperatorResult,
+                                   common_x_vals: List, common_backends: List[str]) -> Dict[str, Dict[str, float]]:
+    """Calculate performance summary statistics."""
+    summary = {}
+
+    # Create result dictionaries for easier lookup
+    result_dict_a = {x_val: metrics_dict for x_val, metrics_dict in result_a.result}
+    result_dict_b = {x_val: metrics_dict for x_val, metrics_dict in result_b.result}
+
+    for backend in common_backends:
+        backend_summary = {}
+
+        for metric in result_a.metrics:
+            improvements = []
+
+            for x_val in common_x_vals:
+                if backend in result_dict_a[x_val] and backend in result_dict_b[x_val]:
+                    metrics_a = result_dict_a[x_val][backend]
+                    metrics_b = result_dict_b[x_val][backend]
+
+                    val_a = getattr(metrics_a, metric, None)
+                    val_b = getattr(metrics_b, metric, None)
+
+                    if val_a is not None and val_b is not None:
+                        # Handle different metric types
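+                        # (latency-style metrics expose percentiles; the median,
+                        # p50, is taken as the scalar value to compare)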
+                        if hasattr(val_a, 'p50'):
+                            val_a_num = val_a.p50
+                        else:
+                            val_a_num = val_a
+
+                        if hasattr(val_b, 'p50'):
+                            val_b_num = val_b.p50
+                        else:
+                            val_b_num = val_b
+
+                        if val_a_num != 0:
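+                            # Signed % change from A to B: negative means B is
+                            # lower, i.e. an improvement for time-based metrics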
+                            improvement = ((val_b_num - val_a_num) / val_a_num) * 100
+                            improvements.append(improvement)
+
+            if improvements:
+                backend_summary[metric] = {
+                    'avg_improvement': sum(improvements) / len(improvements),
+                    'min_improvement': min(improvements),
+                    'max_improvement': max(improvements),
+                    'count': len(improvements)
+                }
+
+        summary[backend] = backend_summary
+
+    return summary
+
+
 def compare_ab_results(result_a: BenchmarkOperatorResult, result_b: BenchmarkOperatorResult,
                        config_a_args: List[str], config_b_args: List[str]):
-    """Compare A/B test results and display formatted comparison."""
+    """Compare A/B test results."""
     if not result_a or not result_b:
         print("\n[A/B Comparison] ERROR: One or both results are invalid")
         return
 
-    print("\n" + "=" * 80)
-    print(f"[A/B Test Results Comparison] - {result_a.op_name}")
-    print("=" * 80)
-    print(f"Configuration A: {' '.join(config_a_args)}")
-    print(f"Configuration B: {' '.join(config_b_args)}")
-    print()
-
     # Check if both results have data
     if not result_a.result or not result_b.result:
         print("ERROR: No benchmark data available for comparison")
         return
 
-    # Get all x_vals (input shapes) that are common to both results
+    # Get common data for analysis
     x_vals_a = {x_val for x_val, _ in result_a.result}
     x_vals_b = {x_val for x_val, _ in result_b.result}
     common_x_vals = sorted(x_vals_a.intersection(x_vals_b))
@@ -135,11 +222,10 @@ def compare_ab_results(result_a: BenchmarkOperatorResult, result_b: BenchmarkOpe
         print("ERROR: No common input shapes found between configurations")
         return
 
-    # Create result dictionaries for easier lookup
+    # Build per-shape lookup dicts, then collect the common backends
     result_dict_a = {x_val: metrics_dict for x_val, metrics_dict in result_a.result}
     result_dict_b = {x_val: metrics_dict for x_val, metrics_dict in result_b.result}
 
-    # Get all backends that are common to both results
     all_backends_a = set()
     all_backends_b = set()
     for x_val in common_x_vals:
@@ -151,58 +237,89 @@ def compare_ab_results(result_a: BenchmarkOperatorResult, result_b: BenchmarkOpe
         print("ERROR: No common backends found between configurations")
         return
 
-    print(f"Comparing {len(common_x_vals)} input shapes across {len(common_backends)} backends")
+    # ============================================================================
+    # SECTION 1: Configuration Analysis
+    # ============================================================================
+    print("\n" + "=" * 70)
+    print(f"A/B Test Results: {result_a.op_name}")
+    print("=" * 70)
+
+    print("Configuration Differences:")
+    differences = _analyze_config_differences(config_a_args, config_b_args)
+
+    if differences:
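+        # Each differing parameter prints as, e.g. (hypothetical values):
+        #   num_inputs     : 10              → 20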
+        for param, (val_a, val_b) in differences.items():
+            print(f"  {param:<15}: {val_a:<15} → {val_b}")
+    else:
+        print("  No configuration differences detected")
+
+    print(f"\nTest Scope: {len(common_x_vals)} input shapes, {len(common_backends)} backends")
     print(f"Metrics: {', '.join(result_a.metrics)}")
-    print()
 
-    # Create comparison table
-    x_val_name = REGISTERED_X_VALS.get(result_a.op_name, "x_val")
+    # ============================================================================
+    # SECTION 2: Performance Summary
+    # ============================================================================
+    print("\n" + "-" * 70)
+    print("Performance Summary")
+    print("-" * 70)
+
+    summary = _calculate_performance_summary(result_a, result_b, common_x_vals, common_backends)
 
     for backend in common_backends:
-        print(f"Backend: {backend}")
-        print("-" * 60)
+        print(f"\n{backend}:")
+        backend_data = summary.get(backend, {})
 
-        # Create table headers
-        headers = [x_val_name]
-        for metric in result_a.metrics:
-            headers.extend([f"{metric}_A", f"{metric}_B", f"{metric}_diff%"])
+        if not backend_data:
+            print("  No comparable data")
+            continue
 
-        # Print headers
-        print("{:<15} ".format(headers[0]), end="")
-        for i in range(1, len(headers)):
-            print("{:<12} ".format(headers[i]), end="")
-        print()
-        print("-" * (15 + 12 * (len(headers) - 1)))
+        for metric, stats in backend_data.items():
+            avg_improvement = stats['avg_improvement']
+            min_improvement = stats['min_improvement']
+            max_improvement = stats['max_improvement']
+
+            print(f"  {metric:<12}: {avg_improvement:+5.1f}% avg [{min_improvement:+.1f}% to {max_improvement:+.1f}%]")
+
+    # ============================================================================
+    # SECTION 3: Detailed Comparison (Compact)
+    # ============================================================================
+    print("\n" + "-" * 70)
+    print("Detailed Comparison")
+    print("-" * 70)
+
+    x_val_name = REGISTERED_X_VALS.get(result_a.op_name, "x_val")
+
+    # Show all metrics for detailed comparison
+    for metric in result_a.metrics:
+        print(f"\nMetric: {metric}")
+        print("Backend".ljust(15), end="")
+        print(x_val_name.ljust(20), end="")
+        print("Config A".ljust(12), end="")
+        print("Config B".ljust(12), end="")
+        print("Difference".ljust(12))
+        print("-" * 71)
 
-        # Print data rows
-        for x_val in common_x_vals:
-            if backend not in result_dict_a[x_val] or backend not in result_dict_b[x_val]:
-                continue
+        for backend in common_backends:
+            first_row = True
+            for x_val in common_x_vals:
+                if backend not in result_dict_a[x_val] or backend not in result_dict_b[x_val]:
+                    continue
+
+                metrics_a = result_dict_a[x_val][backend]
+                metrics_b = result_dict_b[x_val][backend]
 
-            metrics_a = result_dict_a[x_val][backend]
-            metrics_b = result_dict_b[x_val][backend]
-
-            # Print x_val
-            print("{:<15} ".format(str(x_val)), end="")
-
-            # Print metrics comparisons
-            for metric in result_a.metrics:
                 val_a = getattr(metrics_a, metric, None)
                 val_b = getattr(metrics_b, metric, None)
 
                 if val_a is not None and val_b is not None:
-                    # Handle latency objects
+                    # Handle different data types (A and B are assumed to share a type per metric)
                     if hasattr(val_a, 'p50'):
                         val_a_num = val_a.p50
-                    else:
-                        val_a_num = val_a
-
-                    if hasattr(val_b, 'p50'):
                         val_b_num = val_b.p50
                     else:
+                        val_a_num = val_a
                         val_b_num = val_b
 
-                    # Calculate percentage difference
                     if val_a_num != 0:
                         diff_pct = ((val_b_num - val_a_num) / val_a_num) * 100
                     else:
@@ -216,13 +333,14 @@ def compare_ab_results(result_a: BenchmarkOperatorResult, result_b: BenchmarkOpe
                     val_a_str = str(val_a_num)
                     val_b_str = str(val_b_num)
 
-                    print("{:<12} {:<12} {:<12} ".format(
-                        val_a_str, val_b_str, f"{diff_pct:+.1f}%"
-                    ), end="")
-                else:
-                    print("{:<12} {:<12} {:<12} ".format("N/A", "N/A", "N/A"), end="")
-            print()
-        print()
+                    # Print row
+                    backend_name = backend if first_row else ""
+                    print(f"{backend_name:<15}{str(x_val):<20}{val_a_str:<12}{val_b_str:<12}{diff_pct:+5.1f}%")
+                    first_row = False
+
+            if not first_row:  # Only print separator if we printed data
+                print()
+
 
 
 
 def run_ab_test(base_args: argparse.Namespace, base_extra_args: List[str], _run_func) -> Tuple[BenchmarkOperatorResult, BenchmarkOperatorResult]: