Skip to content

Commit fea84d6

Browse files
committed
resolve num examples diff
1 parent 31bffd4 commit fea84d6

File tree

2 files changed

+23
-20
lines changed

2 files changed

+23
-20
lines changed

verifiers/utils/eval_tui.py

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,8 @@ class EnvEvalState:
4242
# updated by on_progress callback
4343
progress: int = 0 # completed rollouts
4444
total: int = 0 # total rollouts
45+
num_examples: int = -1 # num examples (-1 means "all", updated by on_start)
46+
rollouts_per_example: int = 1 # rollouts per example (from config)
4547
reward: float = 0.0 # reward (rolling avg)
4648
metrics: dict[str, float] = field(default_factory=dict) # metrics (rolling avg)
4749
error_rate: float = 0.0 # error rate (rolling avg)
@@ -86,14 +88,19 @@ def __init__(self, configs: list[EvalConfig]):
8688
# initialize env states
8789
for config in configs:
8890
total = config.num_examples * config.rollouts_per_example
89-
self.state.envs[config.env_id] = EnvEvalState(total=total)
91+
self.state.envs[config.env_id] = EnvEvalState(
92+
total=total,
93+
num_examples=config.num_examples,
94+
rollouts_per_example=config.rollouts_per_example,
95+
)
9096

9197
def update_env_state(
9298
self,
9399
env_id: str,
94100
status: Literal["pending", "running", "completed", "failed"] | None = None,
95101
progress: int | None = None,
96102
total: int | None = None,
103+
num_examples: int | None = None,
97104
reward: float | None = None,
98105
metrics: dict[str, float] | None = None,
99106
error_rate: float | None = None,
@@ -118,6 +125,9 @@ def update_env_state(
118125
if total is not None:
119126
env_state.total = total
120127

128+
if num_examples is not None:
129+
env_state.num_examples = num_examples
130+
121131
if reward is not None:
122132
env_state.reward = reward
123133

@@ -139,13 +149,10 @@ def update_env_state(
139149
self.refresh()
140150

141151
def _get_error_rate_color(self, error_rate: float) -> str:
142-
"""Get color for error rate: green at 0.0, red at 1.0."""
143-
# clamp to [0, 1]
144-
error_rate = max(0.0, min(1.0, error_rate))
145-
# interpolate from green (0, 255, 0) to red (255, 0, 0)
146-
red = int(255 * error_rate)
147-
green = int(255 * (1 - error_rate))
148-
return f"rgb({red},{green},0)"
152+
"""Get color for error rate: red if > 10%, otherwise default."""
153+
if error_rate > 0.10:
154+
return "red"
155+
return "white"
149156

150157
def _make_metrics_row(
151158
self, reward: float, metrics: dict[str, float], error_rate: float
@@ -208,17 +215,10 @@ def _make_env_panel(self, env_id: str) -> Panel:
208215
config_line.append(" via ", style="dim")
209216
config_line.append(config.client_config.api_base_url, style="white")
210217
config_line.append(" | ", style="dim")
211-
if config.num_examples == -1:
212-
config_line.append("all", style="white")
213-
config_line.append(" examples", style="dim")
214-
config_line.append(" and ", style="dim")
215-
config_line.append(str(config.rollouts_per_example), style="white")
216-
config_line.append(" rollouts", style="dim")
217-
else:
218-
config_line.append(str(config.num_examples), style="white")
219-
config_line.append("x", style="white")
220-
config_line.append(str(config.rollouts_per_example), style="white")
221-
config_line.append(" rollouts", style="dim")
218+
config_line.append(str(env_state.num_examples), style="white")
219+
config_line.append("x", style="white")
220+
config_line.append(str(env_state.rollouts_per_example), style="white")
221+
config_line.append(" rollouts", style="dim")
222222

223223
def fmt_concurrency(val: int) -> str:
224224
return "∞" if val == -1 else str(val)

verifiers/utils/eval_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,10 @@ async def run_with_progress(env_config: EvalConfig) -> GenerateOutputs:
393393
error_accum = 0
394394

395395
def on_start(total: int) -> None:
396-
tui.update_env_state(env_id, total=total)
396+
# total is num_examples * rollouts_per_example
397+
# compute actual num_examples (resolves -1 to actual count)
398+
num_examples = total // env_config.rollouts_per_example
399+
tui.update_env_state(env_id, total=total, num_examples=num_examples)
397400

398401
def on_progress(all_states: list[State], new_states: list[State]) -> None:
399402
nonlocal error_accum, reward_accum, metrics_accum

0 commit comments

Comments
 (0)