@@ -853,8 +853,8 @@ def _check_interval(self) -> None:
853853 )
854854
855855 def _check_explorer_input (self ) -> None :
856- if self .mode == "train" :
857- # no need to check explorer_input in train mode
856+ if self .mode in { "train" , "bench" , "serve" } :
857+ # no need to check explorer_input in train/bench/serve mode
858858 return
859859
860860 explorer_input = self .buffer .explorer_input
@@ -866,9 +866,8 @@ def _check_explorer_input(self) -> None:
866866 explorer_input .taskset = None
867867 elif len (explorer_input .tasksets ) == 0 :
868868 raise ValueError ("At least one taskset should be provided in explorer_input!" )
869- tasksets = explorer_input .tasksets
870869
871- for i , taskset in enumerate (tasksets ):
870+ for i , taskset in enumerate (explorer_input . tasksets ):
872871 if self .mode != "train" and not taskset .path :
873872 raise ValueError (
874873 "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
@@ -914,6 +913,10 @@ def _check_explorer_input(self) -> None:
914913 set_if_none (dataset .rollout_args , "max_tokens" , self .model .max_response_tokens )
915914
916915 def _check_trainer_input (self ) -> None :
916+ if self .mode in {"explore" , "bench" , "serve" }:
917+ # no need to check trainer_input in explore/bench/serve mode
918+ return
919+
917920 trainer_input = self .buffer .trainer_input
918921 experience_buffer = trainer_input .experience_buffer
919922
@@ -973,7 +976,7 @@ def _default_storage_path(self, storage_type: StorageType, name: str) -> str:
973976 def _check_data_processor (self ) -> None :
974977 # check input/output buffers in pipelines
975978 experience_pipeline = self .data_processor .experience_pipeline
976- if experience_pipeline is not None :
979+ if experience_pipeline is not None and self . mode in { "explore" , "both" , "serve" } :
977980 if experience_pipeline .save_input and experience_pipeline .input_save_path is None :
978981 experience_pipeline .input_save_path = os .path .join (
979982 self .buffer .cache_dir , "explorer_output.jsonl" # type: ignore[arg-type]
@@ -983,7 +986,7 @@ def _check_data_processor(self) -> None:
983986 )
984987
985988 task_pipeline = self .data_processor .task_pipeline
986- if task_pipeline is not None :
989+ if task_pipeline is not None and self . mode in { "explore" , "both" } :
987990 if task_pipeline .output is None :
988991 if self .mode != "train" :
989992 task_pipeline .output = self .buffer .explorer_input .tasksets [0 ]
0 commit comments