29
29
# --- Cost Loading ---
30
30
31
31
32
- def load_model_costs (file_path : str ) -> Dict :
33
- """Loads model costs from a CSV file and returns a structured dictionary .
32
+ def load_model_costs (file_path : str ) -> tuple [ Dict , Dict ] :
33
+ """Loads model costs and friendly names from a CSV file and returns structured dictionaries .
34
34
35
35
Args:
36
36
file_path: The path to the cost file.
37
37
38
38
Returns:
39
- A dictionary containing the model costs .
39
+ A tuple containing (model_costs_dict, friendly_names_dict) .
40
40
"""
41
41
try :
42
42
with open (file_path , "r" , encoding = "utf-8" ) as f :
@@ -45,24 +45,45 @@ def load_model_costs(file_path: str) -> Dict:
45
45
line for line in f if not line .strip ().startswith ("#" ) and line .strip ()
46
46
]
47
47
48
- # Find the start of the dictionary-like definition
48
+ # Find the start of the dictionary-like definitions
49
49
dict_str = "" .join (lines )
50
- match = re .search (r"MODEL_COSTS\s*=\s*({.*})" , dict_str , re .DOTALL )
51
- if not match :
52
- st .error (f"Could not find 'MODEL_COSTS' dictionary in { file_path } " )
53
- return {}
54
-
55
- # Safely evaluate the dictionary string
56
- model_costs_raw = eval (match .group (1 ), {"float" : float })
57
50
58
- return model_costs_raw
51
+ # Extract MODEL_COSTS
52
+ costs_match = re .search (r"MODEL_COSTS\s*=\s*({.*})" , dict_str , re .DOTALL )
53
+ if not costs_match :
54
+ st .error (f"Could not find 'MODEL_COSTS' dictionary in { file_path } " )
55
+ return {}, {}
56
+
57
+ # Safely evaluate the dictionary strings
58
+ model_costs_raw = eval (costs_match .group (1 ), {"float" : float })
59
+
60
+ # Extract friendly names from the inline entries
61
+ friendly_names = {}
62
+ model_costs_clean = {}
63
+
64
+ for model_id , model_data in model_costs_raw .items ():
65
+ # Extract friendly name if it exists
66
+ if isinstance (model_data , dict ) and "friendly_name" in model_data :
67
+ friendly_names [model_id ] = model_data ["friendly_name" ]
68
+ # Create a clean copy without the friendly_name for cost calculations
69
+ model_costs_clean [model_id ] = {
70
+ key : value
71
+ for key , value in model_data .items ()
72
+ if key != "friendly_name"
73
+ }
74
+ else :
75
+ # No friendly name, use model_id as fallback
76
+ friendly_names [model_id ] = model_id
77
+ model_costs_clean [model_id ] = model_data
78
+
79
+ return model_costs_clean , friendly_names
59
80
60
81
except FileNotFoundError :
61
82
st .warning (f"Cost file not found at { file_path } . Using empty cost config." )
62
- return {}
83
+ return {}, {}
63
84
except (SyntaxError , NameError , Exception ) as e :
64
85
st .error (f"Error parsing cost file { file_path } : { e } " )
65
- return {}
86
+ return {}, {}
66
87
67
88
68
89
# --- Data Loading and Processing ---
@@ -388,8 +409,76 @@ def create_leaderboard(
388
409
return leaderboard .sort_values ("Correct" , ascending = sort_ascending )
389
410
390
411
412
def _calculate_smart_label_positions(
    x_data, y_data, labels, min_distance_threshold=0.1
):
    """Choose a Plotly ``textposition`` for each point so nearby labels differ.

    Coordinates are min-max normalized to the 0-1 range independently per
    axis, so the distance threshold is comparable across axes with very
    different scales. Points are then processed in order: each point takes
    the first candidate position that is not already used by an earlier
    point lying closer than ``min_distance_threshold``.

    Args:
        x_data: Array-like of x coordinates (raw values; normalized here).
        y_data: Array-like of y coordinates (raw values; normalized here).
        labels: Array-like of label strings. Accepted for interface
            compatibility; the placement only depends on the coordinates.
        min_distance_threshold: Normalized Euclidean distance below which
            two points are considered close enough for labels to collide.

    Returns:
        List of textposition strings, one per point. Isolated points get
        "top center". An empty input yields an empty list.
    """
    import numpy as np

    # Candidate positions in order of preference.
    position_options = (
        "top center",
        "bottom center",
        "middle left",
        "middle right",
        "top left",
        "top right",
        "bottom left",
        "bottom right",
    )

    x_values = np.asarray(x_data, dtype=float)
    y_values = np.asarray(y_data, dtype=float)

    # min()/max() raise on zero-size arrays, so short-circuit empty input.
    if x_values.size == 0:
        return []

    def _normalize(values):
        """Scale values to 0-1; a constant axis collapses to all zeros."""
        low = values.min()
        span = values.max() - low
        if span != 0:
            return (values - low) / span
        return values * 0

    x_norm = _normalize(x_values)
    y_norm = _normalize(y_values)

    positions = []
    for i in range(len(x_norm)):
        # Positions already claimed by earlier points that sit too close.
        taken = {
            positions[j]
            for j in range(i)
            if np.hypot(x_norm[i] - x_norm[j], y_norm[i] - y_norm[j])
            < min_distance_threshold
        }
        chosen = next(
            (option for option in position_options if option not in taken),
            None,
        )
        if chosen is None:
            # More crowded neighbours than distinct positions: cycle through
            # the options so the result stays deterministic.
            chosen = position_options[i % len(position_options)]
        positions.append(chosen)

    return positions
474
+
475
+
391
476
def create_pareto_frontier_plot (
392
- df : pd .DataFrame , selected_groups : List [str ], x_axis_mode : str , config : Dict
477
+ df : pd .DataFrame ,
478
+ selected_groups : List [str ],
479
+ x_axis_mode : str ,
480
+ config : Dict ,
481
+ friendly_names : Dict = None ,
393
482
) -> go .Figure :
394
483
"""Visualizes the trade-off between model performance and cost/token usage.
395
484
@@ -436,29 +525,64 @@ def create_pareto_frontier_plot(
436
525
else :
437
526
hover_format = ":.0f"
438
527
439
- fig .add_trace (
440
- go .Scatter (
441
- x = x_data ,
442
- y = model_metrics ["y_axis" ],
443
- mode = "markers+text" ,
444
- marker = dict (
445
- size = 18 ,
446
- color = model_metrics ["color_axis" ],
447
- colorscale = "RdYlGn_r" ,
448
- showscale = True ,
449
- colorbar = dict (title = f"Avg { plot_config ['color_axis' ]} (s)" ),
450
- ),
451
- text = model_metrics ["Model" ],
452
- textposition = "top center" ,
453
- hovertemplate = (
454
- "<b>%{text}</b><br>"
455
- f"{ y_axis_label } : %{{y:.1f}}%<br>"
456
- f"{ hover_label } : %{{x{ hover_format } }}<br>"
457
- f"Avg { plot_config ['color_axis' ]} : %{{marker.color:.1f}}s<extra></extra>"
458
- ),
459
- )
528
+ # Calculate smart label positions to avoid overlaps
529
+ label_positions = _calculate_smart_label_positions (
530
+ x_data .values , model_metrics ["y_axis" ].values , model_metrics ["Model" ].values
460
531
)
461
532
533
+ # Group data by text position to create separate traces
534
+ from collections import defaultdict
535
+
536
+ position_groups = defaultdict (list )
537
+
538
+ for i , position in enumerate (label_positions ):
539
+ position_groups [position ].append (i )
540
+
541
+ # Create a trace for each text position group
542
+ first_trace = True
543
+ for position , indices in position_groups .items ():
544
+ x_vals = [x_data .iloc [i ] for i in indices ]
545
+ y_vals = [model_metrics ["y_axis" ].iloc [i ] for i in indices ]
546
+ colors = [model_metrics ["color_axis" ].iloc [i ] for i in indices ]
547
+
548
+ # Get model names for this position group
549
+ original_names = [model_metrics ["Model" ].iloc [i ] for i in indices ]
550
+
551
+ # Use friendly names for display if available, otherwise use original names
552
+ if friendly_names :
553
+ display_texts = [friendly_names .get (name , name ) for name in original_names ]
554
+ else :
555
+ display_texts = original_names
556
+
557
+ fig .add_trace (
558
+ go .Scatter (
559
+ x = x_vals ,
560
+ y = y_vals ,
561
+ mode = "markers+text" ,
562
+ marker = dict (
563
+ size = 18 ,
564
+ color = colors ,
565
+ colorscale = "RdYlGn_r" ,
566
+ showscale = first_trace , # Show colorbar only on first trace
567
+ colorbar = dict (title = f"Avg { plot_config ['color_axis' ]} (s)" )
568
+ if first_trace
569
+ else None ,
570
+ ),
571
+ text = display_texts , # Use friendly names for display
572
+ textposition = position ,
573
+ customdata = original_names , # Store original names for hover
574
+ hovertemplate = (
575
+ "<b>%{text}</b><br>" # Friendly name as title
576
+ "API Name: %{customdata}<br>" # Original API name
577
+ f"{ y_axis_label } : %{{y:.1f}}%<br>"
578
+ f"{ hover_label } : %{{x{ hover_format } }}<br>"
579
+ f"Avg { plot_config ['color_axis' ]} : %{{marker.color:.1f}}s<extra></extra>"
580
+ ),
581
+ showlegend = False , # Don't show legend for individual position groups
582
+ )
583
+ )
584
+ first_trace = False
585
+
462
586
fig .update_layout (
463
587
title = plot_config ["title" ].format (x_axis_label = x_title ),
464
588
xaxis_title = f"Average { x_title } " ,
@@ -714,7 +838,7 @@ def main() -> None:
714
838
# Cost configuration in sidebar
715
839
st .sidebar .subheader ("💰 Cost Configuration" )
716
840
cost_file_path = os .path .join (os .path .dirname (__file__ ), "costs.csv" )
717
- model_costs = load_model_costs (cost_file_path )
841
+ model_costs , friendly_names = load_model_costs (cost_file_path )
718
842
available_models = sorted (df_initial ["Model" ].unique ())
719
843
720
844
cost_config = {}
@@ -833,7 +957,7 @@ def main() -> None:
833
957
)
834
958
st .plotly_chart (
835
959
create_pareto_frontier_plot (
836
- df , selected_groups , x_axis_mode , eval_config .model_dump ()
960
+ df , selected_groups , x_axis_mode , eval_config .model_dump (), friendly_names
837
961
),
838
962
use_container_width = True ,
839
963
)
0 commit comments