Skip to content

Commit ae3eb7d

Browse files
committed
fix mode check
1 parent 777ab91 commit ae3eb7d

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

trinity/cli/launcher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def bench(config: Config) -> None:
3737

3838
def explore(config: Config) -> None:
3939
"""Run explorer."""
40+
check_and_run_task_pipeline(config)
4041
try:
4142
explorer = Explorer.get_actor(config)
4243
ray.get(explorer.prepare.remote())
@@ -81,6 +82,7 @@ def both(config: Config) -> None:
8182
the latest step. The specific number of experiences may vary for different
8283
algorithms and tasks.
8384
"""
85+
check_and_run_task_pipeline(config)
8486
try:
8587
explorer = Explorer.get_actor(config)
8688
trainer = Trainer.get_actor(config)
@@ -151,7 +153,6 @@ def run_stage(config: Config) -> None:
151153
)
152154
pprint(config)
153155
try:
154-
check_and_run_task_pipeline(config)
155156
MODE_MAP[config.mode](config)
156157
finally:
157158
if config.monitor.enable_ray_timeline:

trinity/common/config.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -853,8 +853,8 @@ def _check_interval(self) -> None:
853853
)
854854

855855
def _check_explorer_input(self) -> None:
856-
if self.mode == "train":
857-
# no need to check explorer_input in train mode
856+
if self.mode in {"train", "bench", "serve"}:
857+
# no need to check explorer_input in train/bench/serve mode
858858
return
859859

860860
explorer_input = self.buffer.explorer_input
@@ -866,9 +866,8 @@ def _check_explorer_input(self) -> None:
866866
explorer_input.taskset = None
867867
elif len(explorer_input.tasksets) == 0:
868868
raise ValueError("At least one taskset should be provided in explorer_input!")
869-
tasksets = explorer_input.tasksets
870869

871-
for i, taskset in enumerate(tasksets):
870+
for i, taskset in enumerate(explorer_input.tasksets):
872871
if self.mode != "train" and not taskset.path:
873872
raise ValueError(
874873
"`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
@@ -914,6 +913,10 @@ def _check_explorer_input(self) -> None:
914913
set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens)
915914

916915
def _check_trainer_input(self) -> None:
916+
if self.mode in {"explore", "bench", "serve"}:
917+
# no need to check trainer_input in train/bench/serve mode
918+
return
919+
917920
trainer_input = self.buffer.trainer_input
918921
experience_buffer = trainer_input.experience_buffer
919922

@@ -973,7 +976,7 @@ def _default_storage_path(self, storage_type: StorageType, name: str) -> str:
973976
def _check_data_processor(self) -> None:
974977
# check input/output buffers in pipelines
975978
experience_pipeline = self.data_processor.experience_pipeline
976-
if experience_pipeline is not None:
979+
if experience_pipeline is not None and self.mode in {"explore", "both", "serve"}:
977980
if experience_pipeline.save_input and experience_pipeline.input_save_path is None:
978981
experience_pipeline.input_save_path = os.path.join(
979982
self.buffer.cache_dir, "explorer_output.jsonl" # type: ignore[arg-type]
@@ -983,7 +986,7 @@ def _check_data_processor(self) -> None:
983986
)
984987

985988
task_pipeline = self.data_processor.task_pipeline
986-
if task_pipeline is not None:
989+
if task_pipeline is not None and self.mode in {"explore", "both"}:
987990
if task_pipeline.output is None:
988991
if self.mode != "train":
989992
task_pipeline.output = self.buffer.explorer_input.tasksets[0]

trinity/explorer/explorer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def __init__(self, config: Config):
5252
self.models, self.auxiliary_models = create_inference_models(config)
5353
self.experience_pipeline = self._init_experience_pipeline()
5454
self.taskset = (
55-
TasksetScheduler(explorer_state, config) if self.config.mode != "serve" else None
55+
TasksetScheduler(explorer_state, config)
56+
if self.config.mode not in {"bench", "serve"}
57+
else None
5658
)
5759
self.scheduler = None
5860
self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
@@ -406,6 +408,8 @@ async def is_alive(self) -> bool:
406408

407409
def _init_experience_pipeline(self) -> ray.actor.ActorHandle:
408410
"""Init experience pipeline for the explorer."""
411+
if self.config.mode == "bench":
412+
return None
409413
node_id = ray.get_runtime_context().get_node_id()
410414
return (
411415
ray.remote(ExperiencePipeline)

0 commit comments

Comments
 (0)