Skip to content

Commit 0102f2d

Browse files
committed
For proposer, solver model, skip file exists check as they might be on HF hub.
1 parent 977398f commit 0102f2d

File tree

3 files changed

+11
-29
lines changed

3 files changed

+11
-29
lines changed

wandering_light/evals/evaluate_proposer.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -290,15 +290,9 @@ def file_evaluate_proposer(
290290
if isinstance(model, str):
291291
model = TrainedLLMTokenGenerator(model)
292292
if isinstance(solver_model, str):
293-
if os.path.exists(solver_model):
294-
solver_model = create_token_solver(
295-
TrainedLLMTokenGenerator(solver_model), budget=1
296-
)
297-
else:
298-
print(
299-
f"Solver model not found at {solver_model}, skipping solver based evaluation"
300-
)
301-
solver_model = None
293+
solver_model = create_token_solver(
294+
TrainedLLMTokenGenerator(solver_model), budget=1
295+
)
302296

303297
trajectories = TrajectoryList.from_file(eval_file)
304298
num_samples = num_samples or len(trajectories)

wandering_light/training/rl_grpo.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -200,18 +200,11 @@ def _load_eval_data(self):
200200

201201
# Load solver model for proposer task
202202
if self.task == Task.PROPOSER:
203-
if os.path.exists(DEFAULT_SOLVER_CHECKPOINT):
204-
self.solver_model = create_token_solver(
205-
TrainedLLMTokenGenerator(
206-
DEFAULT_SOLVER_CHECKPOINT, temperature=0.8
207-
),
208-
budget=1,
209-
)
210-
logger.info(f"Loaded solver model from {DEFAULT_SOLVER_CHECKPOINT}")
211-
else:
212-
logger.warning(
213-
"Solver checkpoint not found, skipping solver based evaluation for proposer task"
214-
)
203+
self.solver_model = create_token_solver(
204+
TrainedLLMTokenGenerator(DEFAULT_SOLVER_CHECKPOINT, temperature=0.8),
205+
budget=1,
206+
)
207+
logger.info(f"Loaded solver model from {DEFAULT_SOLVER_CHECKPOINT}")
215208

216209
def _run_evaluation(self, model_path: str):
217210
"""Run evaluation on the current model using pre-computed trajectories."""

wandering_light/training/sft.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,9 @@ def _load_eval_data(self):
5454
self.trajectories = None
5555
self.available_functions = None
5656
if self.task == Task.PROPOSER:
57-
if os.path.exists(DEFAULT_SOLVER_CHECKPOINT):
58-
self.solver_model = create_token_solver(
59-
TrainedLLMTokenGenerator(DEFAULT_SOLVER_CHECKPOINT), budget=1
60-
)
61-
else:
62-
print(
63-
"Solver checkpoint not found, skipping solver based evaluation for proposer task"
64-
)
57+
self.solver_model = create_token_solver(
58+
TrainedLLMTokenGenerator(DEFAULT_SOLVER_CHECKPOINT), budget=1
59+
)
6560

6661
def _run_evaluation(self, model_path: str):
6762
"""Run evaluation on the current model using pre-computed trajectories."""

0 commit comments

Comments
 (0)