Draft
40 commits, all by mikasenghaas on Jan 15, 2026:

207cfee  poc vf-eval tui
db4fc58  exit on input
0767dc2  streaming works
bd68594  full width boxes
70aa7b1  make env id part of border
17bdea0  remove header
614e367  use static env config
0541206  remove redundant info
136939b  show running avg of all metrics
3f4e715  spacing
0e7bbd7  ckpt
70aec2d  fix
ed1e77a  final summary + stack
e46e631  remove global progress
50b615f  spacing
d33ced4  unify progress callback behavior
c659583  show gen/sem concurrency
33d8aaf  show sampling args
117dbb8  show saved results path
a715e5c  formatting
3481ff9  remove print_results
83f59d1  show -1 concurrency with infinite
b82b620  fix
88856b8  on log callback
579a005  show save every
2004380  fix tests
f0b98fe  resolve num_examples=-1
6c7f359  show error
5c66715  cosmetics
29fde50  remove global pbar
7474b0c  refactor progress
38e9917  refactor accums
08c8aa2  fix progress bar
25185cd  minor
b177569  minor
e023756  cleanup
e855aab  fix linter
a5e1340  cleanup
2c77573  resolve num examples diff
74f323f  fix
6 changes: 5 additions & 1 deletion configs/evals/debug.toml
@@ -1,7 +1,11 @@
[[env]]
env_id = "gsm8k"
num_examples = 1
num_examples = 20
rollouts_per_example = 1
sampling_args = { max_tokens = 1024 }
independent_scoring = true
save_results = true
save_every = 10

[[env]]
env_id = "alphabet-sort"
20 changes: 20 additions & 0 deletions configs/evals/single-turn.toml
@@ -0,0 +1,20 @@
[[env]]
env_id = "math500"
num_examples = -1
rollouts_per_example = 1

[[env]]
env_id = "aime2024"
num_examples = -1
rollouts_per_example = 8

[[env]]
env_id = "gpqa"
num_examples = -1
rollouts_per_example = 1

[[env]]
env_id = "livecodebench"
num_examples = -1
rollouts_per_example = 1
max_concurrent = 16 # to limit sandbox usage
1 change: 0 additions & 1 deletion docs/reference.md
@@ -542,7 +542,6 @@ class EvalConfig(BaseModel):
max_concurrent_generation: int | None = None
max_concurrent_scoring: int | None = None
extra_env_kwargs: dict = {}
print_results: bool = False
verbose: bool = False
state_columns: list[str] | None = None
save_results: bool = False
4 changes: 2 additions & 2 deletions tests/test_eval_cli.py
@@ -114,14 +114,14 @@ def _run_cli(monkeypatch, overrides, capture_all_configs: bool = False):
"temperature": 0.9,
"sampling_args": None,
"verbose": False,
"print_results": False,
"no_interleave_scoring": False,
"state_columns": [],
"save_results": False,
"save_every": -1,
"save_to_hf_hub": False,
"hf_hub_dataset_name": "",
"extra_env_kwargs": {},
"tui": False,
}
base_args.update(overrides)
args_namespace = SimpleNamespace(**base_args)
@@ -136,7 +136,7 @@ def _run_cli(monkeypatch, overrides, capture_all_configs: bool = False):
monkeypatch.setattr(vf_eval, "setup_logging", lambda *_, **__: None)
monkeypatch.setattr(vf_eval, "load_endpoints", lambda *_: {})

async def fake_run_evaluation(config):
async def fake_run_evaluation(config, **kwargs):
captured["sampling_args"] = dict(config.sampling_args)
captured["configs"].append(config)
metadata = _make_metadata(config)
109 changes: 52 additions & 57 deletions verifiers/envs/environment.py
@@ -36,12 +36,15 @@
ChatMessage,
GenerateMetadata,
GenerateOutputs,
LogCallback,
Messages,
MessageType,
ModelResponse,
ProgressCallback,
RolloutInput,
RolloutTiming,
SamplingArgs,
StartCallback,
State,
)
from verifiers.utils.async_utils import maybe_semaphore
@@ -830,17 +833,25 @@ async def generate(
state_columns: list[str] | None = None,
save_results: bool = False,
save_every: int = -1,
use_tqdm: bool = True,
independent_scoring: bool = False,
on_start: StartCallback | None = None,
on_progress: ProgressCallback | None = None,
on_log: LogCallback | None = None,
) -> GenerateOutputs:
"""
Generate rollouts for a set of inputs.
"""
on_log = on_log or self.logger.debug

if isinstance(inputs, Dataset):
inputs_list = inputs.to_list()
elif isinstance(inputs, list):
inputs_list = inputs

# notify caller of actual total count (useful when num_examples=-1)
if on_start is not None:
on_start(len(inputs_list))

# resolve concurrency knobs
gen_limit = max_concurrent_generation
score_limit = max_concurrent_scoring
@@ -876,8 +887,6 @@ async def generate(
)
)
tasks[task] = i
pbar_total = len(inputs_list)
pbar_desc = f"Processing {len(inputs_list)} rollouts"
else:
input_groups: dict[int, list[RolloutInput]] = {}
for input_item in inputs_list:
@@ -899,62 +908,41 @@
)
)
tasks[task] = i
pbar_total = len(group_list)
pbar_desc = f"Processing {len(group_list)} groups ({len(inputs_list)} total rollouts)"

# set up progress bar
pbar = None
if use_tqdm:
from tqdm import tqdm

pbar = tqdm(total=pbar_total, desc=pbar_desc, postfix=dict(reward="?"))

# process tasks as they complete
reward_sum, reward_count = 0, 0
groups_or_rollouts_completed = 0
completed_groups_or_rollouts = 0
total_groups_or_rollouts = len(tasks)
all_states: list[State] = []
try:
for coro in asyncio.as_completed(tasks.keys()):
result = await coro
# normalize: independent_scoring returns State, group returns list[State]
states = [result] if independent_scoring else result
all_states.extend(states)
groups_or_rollouts_completed += 1

# track reward for rolling average
for s in states:
r = s.get("reward")
if r is not None:
reward_sum += r
reward_count += 1

if pbar is not None:
pbar.update(1)
if reward_count > 0:
pbar.set_postfix(reward=f"{reward_sum / reward_count:.3f}")

# save intermediate results
if (
save_results
and save_every > 0
and groups_or_rollouts_completed % save_every == 0
):
temp_results = self._prepare_rollout_results(
all_states,
model,
client,
state_columns,
results_path,
gen_sampling_args,
start_time,
)
self.logger.debug(
f"Saving intermediate results to {temp_results['metadata']['path_to_save']}"
)
save_rollout_results(temp_results)
finally:
if pbar is not None:
pbar.close()
for coro in asyncio.as_completed(tasks.keys()):
result = await coro
# normalize: independent_scoring returns State, group returns list[State]
new_states = [result] if independent_scoring else result
all_states.extend(new_states)
completed_groups_or_rollouts += 1

# call progress callback with all finished states and new states
if on_progress is not None:
on_progress(all_states, new_states)

# save intermediate results
if (
save_results
and save_every > 0
and completed_groups_or_rollouts % save_every == 0
):
temp_results = self._prepare_rollout_results(
all_states,
model,
client,
state_columns,
results_path,
gen_sampling_args,
start_time,
)
on_log(
f"Saving intermediate results ({completed_groups_or_rollouts}/{total_groups_or_rollouts} {('rollouts' if independent_scoring else 'groups')}) to {temp_results['metadata']['path_to_save']}"
)
save_rollout_results(temp_results)

# sort by example_id to ensure deterministic ordering regardless of completion order
all_states.sort(key=lambda s: s.get("example_id", 0))
@@ -969,9 +957,10 @@
start_time,
)

# Save if requested
# save if requested
if save_results:
save_rollout_results(results)
on_log(f"Saved final results to {results['metadata']['path_to_save']}")

return results

@@ -1041,6 +1030,9 @@ async def evaluate(
save_results: bool = False,
save_every: int = -1,
independent_scoring: bool = False,
on_start: StartCallback | None = None,
on_progress: ProgressCallback | None = None,
on_log: LogCallback | None = None,
**kwargs,
) -> GenerateOutputs:
"""
@@ -1060,6 +1052,9 @@
save_results=save_results,
save_every=save_every,
independent_scoring=independent_scoring,
on_start=on_start,
on_progress=on_progress,
on_log=on_log,
**kwargs,
)

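Since the tqdm bar is removed from generate() in favor of the new hooks, plain (non-TUI) callers that still want a progress bar can supply one themselves. Below is a minimal caller-side sketch that restores the old running-average display; the factory name is made up for illustration, and only the callback signatures (on_start receives the total rollout count, on_progress receives all finished states plus the newly finished ones) come from this diff.

from tqdm import tqdm

def make_tqdm_callbacks():
    # Hypothetical helper, not part of this PR: wires the new on_start /
    # on_progress hooks to a tqdm bar with a running average reward.
    pbar = None

    def on_start(total_rollouts: int) -> None:
        nonlocal pbar
        pbar = tqdm(total=total_rollouts, desc=f"Processing {total_rollouts} rollouts")

    def on_progress(all_states, new_states) -> None:
        if pbar is None:
            return
        # advance by the number of rollouts that just finished
        pbar.update(len(new_states))
        rewards = [s.get("reward") for s in all_states if s.get("reward") is not None]
        if rewards:
            pbar.set_postfix(reward=f"{sum(rewards) / len(rewards):.3f}")

    return on_start, on_progress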
13 changes: 12 additions & 1 deletion verifiers/scripts/eval.py
@@ -20,6 +20,7 @@
load_endpoints,
load_toml_config,
run_multi_evaluation,
run_multi_evaluation_tui,
)

logger = logging.getLogger(__name__)
@@ -250,6 +251,13 @@ def main():
default={},
help='Extra environment as JSON object (e.g., \'{"key": "value", "num": 42}\'). Passed to environment constructor.',
)
parser.add_argument(
"--tui",
"-u",
default=False,
action="store_true",
help="Use TUI mode for live evaluation display",
)
args = parser.parse_args()

setup_logging("DEBUG" if args.verbose else os.getenv("VF_LOG_LEVEL", "INFO"))
@@ -415,7 +423,10 @@ def resolve_eval_config(raw_env_config: dict) -> EvalConfig:
logger.debug(f"Evaluation config: {eval_config.model_dump_json(indent=2)}")

multi_eval_config = MultiEvalConfig(env=eval_configs)
asyncio.run(run_multi_evaluation(multi_eval_config))
if args.tui:
asyncio.run(run_multi_evaluation_tui(multi_eval_config))
else:
asyncio.run(run_multi_evaluation(multi_eval_config))


if __name__ == "__main__":
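The live display is opt-in, so existing invocations are unchanged. Assuming the console entry point for this script is vf-eval (as the test module name suggests), enabling it would look like the following, with all other arguments elided:

vf-eval ... --tui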
6 changes: 6 additions & 0 deletions verifiers/types.py
@@ -139,6 +139,11 @@ def get(self, key: str, default: Any = None) -> Any:
# oai tools
JsonPrimitive = Literal["string", "number", "integer", "boolean", "array", "object"]

# callbacks
StartCallback = Callable[[int], None] # total rollouts
ProgressCallback = Callable[[list[State], list[State]], None] # all_states, new_states
LogCallback = Callable[[str], None] # log messages


class GenerateMetadata(TypedDict):
"""Pydantic model for generation metadata."""
@@ -237,6 +242,7 @@ class EvalConfig(BaseModel):
extra_env_kwargs: dict = {}
# logging
verbose: bool = False
use_tqdm: bool = True
# saving
state_columns: list[str] | None = None
save_results: bool = False
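The three new aliases are plain callables, so any functions with matching signatures can be passed to generate()/evaluate(). An illustrative set is shown below; the function names and print bodies are placeholders, only the signatures come from verifiers/types.py.

from verifiers.types import LogCallback, ProgressCallback, StartCallback, State

def on_start(total_rollouts: int) -> None:
    # called once with the resolved total (useful when num_examples=-1)
    print(f"evaluating {total_rollouts} rollouts")

def on_progress(all_states: list[State], new_states: list[State]) -> None:
    # called per completed group/rollout with all finished states plus the new ones
    print(f"{len(all_states)} finished (+{len(new_states)})")

def on_log(message: str) -> None:
    print(message)

# the aliases are satisfied structurally
start_cb: StartCallback = on_start
progress_cb: ProgressCallback = on_progress
log_cb: LogCallback = on_log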