Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions verifiers/scripts/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,11 @@ def main():
parser.add_argument(
"--save-results",
"-s",
nargs="?",
const=True,
default=False,
action="store_true",
help="Save results to disk",
metavar="PATH",
help="Save results to disk. Optionally specify custom output path.",
)
# save every n rollouts
parser.add_argument(
Expand Down Expand Up @@ -317,8 +319,9 @@ def main():
print_results=True,
verbose=args.verbose,
# saving
output_dir=args.save_results if isinstance(args.save_results, str) else None,
state_columns=args.state_columns,
save_results=args.save_results,
save_results=bool(args.save_results),
save_every=args.save_every,
save_to_hf_hub=args.save_to_hf_hub,
hf_hub_dataset_name=args.hf_hub_dataset_name,
Expand Down
1 change: 1 addition & 0 deletions verifiers/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ class EvalConfig(BaseModel):
print_results: bool = False
verbose: bool = False
# saving
output_dir: str | None = None
state_columns: list[str] | None = None
save_results: bool = False
save_every: int = -1
Expand Down
10 changes: 7 additions & 3 deletions verifiers/utils/path_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@ def get_results_path(


def get_eval_results_path(config: EvalConfig) -> Path:
    """Return the results path for the evaluation described by *config*.

    The base directory is chosen in this order:

    1. ``config.output_dir`` when it is not ``None`` (custom location).
    2. ``<config.env_dir_path>/<module_name>/outputs`` when that local
       environment directory exists, where ``module_name`` is
       ``config.env_id`` with ``-`` replaced by ``_``.
    3. ``./outputs`` relative to the current working directory.

    The chosen base directory is then handed to ``get_results_path`` together
    with ``config.env_id`` and ``config.model``; the layout below the base
    directory is decided by that helper.
    """
    if config.output_dir is not None:
        # An explicit output directory always wins over the defaults.
        base_path = Path(config.output_dir)
    else:
        # Default behavior: prefer the environment's own "outputs" folder,
        # falling back to "./outputs" when the environment is not local.
        module_name = config.env_id.replace("-", "_")
        local_env_dir = Path(config.env_dir_path) / module_name
        if local_env_dir.exists():
            base_path = local_env_dir / "outputs"
        else:
            base_path = Path("./outputs")

    # Single exit point: every branch only differs in the base directory.
    return get_results_path(config.env_id, config.model, base_path)