@@ -42,6 +42,8 @@ class EnvEvalState:
4242 # updated by on_progress callback
4343 progress : int = 0 # completed rollouts
4444 total : int = 0 # total rollouts
45+ num_examples : int = - 1 # num examples (-1 means "all", updated by on_start)
46+ rollouts_per_example : int = 1 # rollouts per example (from config)
4547 reward : float = 0.0 # reward (rolling avg)
4648 metrics : dict [str , float ] = field (default_factory = dict ) # metrics (rolling avg)
4749 error_rate : float = 0.0 # error rate (rolling avg)
@@ -86,14 +88,19 @@ def __init__(self, configs: list[EvalConfig]):
8688 # initialize env states
8789 for config in configs :
8890 total = config .num_examples * config .rollouts_per_example
89- self .state .envs [config .env_id ] = EnvEvalState (total = total )
91+ self .state .envs [config .env_id ] = EnvEvalState (
92+ total = total ,
93+ num_examples = config .num_examples ,
94+ rollouts_per_example = config .rollouts_per_example ,
95+ )
9096
9197 def update_env_state (
9298 self ,
9399 env_id : str ,
94100 status : Literal ["pending" , "running" , "completed" , "failed" ] | None = None ,
95101 progress : int | None = None ,
96102 total : int | None = None ,
103+ num_examples : int | None = None ,
97104 reward : float | None = None ,
98105 metrics : dict [str , float ] | None = None ,
99106 error_rate : float | None = None ,
@@ -118,6 +125,9 @@ def update_env_state(
118125 if total is not None :
119126 env_state .total = total
120127
128+ if num_examples is not None :
129+ env_state .num_examples = num_examples
130+
121131 if reward is not None :
122132 env_state .reward = reward
123133
@@ -139,13 +149,10 @@ def update_env_state(
139149 self .refresh ()
140150
141151 def _get_error_rate_color (self , error_rate : float ) -> str :
142- """Get color for error rate: green at 0.0, red at 1.0."""
143- # clamp to [0, 1]
144- error_rate = max (0.0 , min (1.0 , error_rate ))
145- # interpolate from green (0, 255, 0) to red (255, 0, 0)
146- red = int (255 * error_rate )
147- green = int (255 * (1 - error_rate ))
148- return f"rgb({ red } ,{ green } ,0)"
152+ """Get color for error rate: red if > 10%, otherwise default."""
153+ if error_rate > 0.10 :
154+ return "red"
155+ return "white"
149156
150157 def _make_metrics_row (
151158 self , reward : float , metrics : dict [str , float ], error_rate : float
@@ -208,17 +215,10 @@ def _make_env_panel(self, env_id: str) -> Panel:
208215 config_line .append (" via " , style = "dim" )
209216 config_line .append (config .client_config .api_base_url , style = "white" )
210217 config_line .append (" | " , style = "dim" )
211- if config .num_examples == - 1 :
212- config_line .append ("all" , style = "white" )
213- config_line .append (" examples" , style = "dim" )
214- config_line .append (" and " , style = "dim" )
215- config_line .append (str (config .rollouts_per_example ), style = "white" )
216- config_line .append (" rollouts" , style = "dim" )
217- else :
218- config_line .append (str (config .num_examples ), style = "white" )
219- config_line .append ("x" , style = "white" )
220- config_line .append (str (config .rollouts_per_example ), style = "white" )
221- config_line .append (" rollouts" , style = "dim" )
218+ config_line .append (str (env_state .num_examples ), style = "white" )
219+ config_line .append ("x" , style = "white" )
220+ config_line .append (str (env_state .rollouts_per_example ), style = "white" )
221+ config_line .append (" rollouts" , style = "dim" )
222222
def fmt_concurrency(val: int) -> str:
    """Format a concurrency limit, rendering -1 as the infinity sign."""
    if val == -1:
        return "∞"
    return str(val)
0 commit comments