Skip to content

Commit d68a19b

Browse files
committed
fix unittest and bots_reward
1 parent b6197d8 commit d68a19b

File tree

4 files changed

+31
-15
lines changed

4 files changed

+31
-15
lines changed

examples/bots/workflow/bots_reward.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# Adapted from Reasoning360: https://github.com/LLM360/Reasoning360/blob/main/verl/utils/reward_score/naive_dapo.py
22

3+
import concurrent
34
import contextlib
45
import math
56
import re
7+
import resource
68
from math import isclose
79
from typing import Optional, Union
810

@@ -585,17 +587,25 @@ def should_allow_eval(expr: str):
585587

586588
# @timeout(timeout_seconds=10)
587589
def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
588-
are_equal = False
589-
try:
590+
def check_equal():
591+
memory_size = 1024**3
592+
resource.setrlimit(resource.RLIMIT_AS, (memory_size, memory_size))
593+
590594
expr = f"({ground_truth_normalized})-({given_normalized})"
591595
if should_allow_eval(expr):
592596
sympy_diff = _sympy_parse(expr)
593597
simplified = sympy.simplify(sympy_diff)
594598
if simplified == 0:
595-
are_equal = True
596-
except Exception:
597-
pass
598-
return are_equal
599+
return True
600+
return False
601+
602+
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
603+
future = executor.submit(check_equal)
604+
try:
605+
return future.result(timeout=10)
606+
except (concurrent.futures.TimeoutError, Exception):
607+
future.cancel()
608+
return False
599609

600610

601611
def split_tuple(expr: str):

trinity/buffer/pipelines/task_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66

77
def check_and_run_task_pipeline(config: Config) -> Dict:
8+
if config.mode not in {"explore", "train", "both"}:
9+
return {}
810
if config.data_processor.task_pipeline is None:
911
return {}
1012

trinity/cli/launcher.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def bench(config: Config) -> None:
3737

3838
def explore(config: Config) -> None:
3939
"""Run explorer."""
40-
check_and_run_task_pipeline(config)
4140
try:
4241
explorer = Explorer.get_actor(config)
4342
ray.get(explorer.prepare.remote())
@@ -82,7 +81,6 @@ def both(config: Config) -> None:
8281
the latest step. The specific number of experiences may vary for different
8382
algorithms and tasks.
8483
"""
85-
check_and_run_task_pipeline(config)
8684
try:
8785
explorer = Explorer.get_actor(config)
8886
trainer = Trainer.get_actor(config)
@@ -153,6 +151,7 @@ def run_stage(config: Config) -> None:
153151
)
154152
pprint(config)
155153
try:
154+
check_and_run_task_pipeline(config)
156155
MODE_MAP[config.mode](config)
157156
finally:
158157
if config.monitor.enable_ray_timeline:

trinity/common/config.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -853,8 +853,8 @@ def _check_interval(self) -> None:
853853
)
854854

855855
def _check_explorer_input(self) -> None:
856-
if self.mode in {"train", "bench", "serve"}:
857-
# no need to check explorer_input in train/bench/serve mode
856+
if self.mode == "serve":
857+
# no need to check explorer_input in serve mode
858858
return
859859

860860
explorer_input = self.buffer.explorer_input
@@ -864,7 +864,7 @@ def _check_explorer_input(self) -> None:
864864
raise ValueError("Do not support setting `taskset` and `tasksets` simultaneously!")
865865
explorer_input.tasksets = [explorer_input.taskset]
866866
explorer_input.taskset = None
867-
elif len(explorer_input.tasksets) == 0:
867+
elif self.mode not in {"bench", "train"} and len(explorer_input.tasksets) == 0:
868868
raise ValueError("At least one taskset should be provided in explorer_input!")
869869

870870
for i, taskset in enumerate(explorer_input.tasksets):
@@ -913,8 +913,8 @@ def _check_explorer_input(self) -> None:
913913
set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens)
914914

915915
def _check_trainer_input(self) -> None:
916-
if self.mode in {"explore", "bench", "serve"}:
917-
# no need to check trainer_input in explore/bench/serve mode
916+
if self.mode in {"bench", "serve"}:
917+
# no need to check trainer_input in bench/serve mode
918918
return
919919

920920
trainer_input = self.buffer.trainer_input
@@ -986,10 +986,15 @@ def _check_data_processor(self) -> None:
986986
)
987987

988988
task_pipeline = self.data_processor.task_pipeline
989-
if task_pipeline is not None and self.mode in {"explore", "both"}:
989+
if task_pipeline is not None and self.mode in {"explore", "train", "both"}:
990990
if task_pipeline.output is None:
991991
if self.mode != "train":
992-
task_pipeline.output = self.buffer.explorer_input.tasksets[0]
992+
if len(self.buffer.explorer_input.tasksets) > 0:
993+
task_pipeline.output = self.buffer.explorer_input.tasksets[0]
994+
else:
995+
raise ValueError(
996+
"At least one taskset should be provided in explorer_input!"
997+
)
993998
elif self.mode == "train" and self.algorithm.algorithm_type in {"dpo", "sft"}:
994999
task_pipeline.output = self.buffer.trainer_input.experience_buffer
9951000
else:

0 commit comments

Comments
 (0)