diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 88d925f786..c9fbacc798 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -310,6 +310,9 @@ Controls the rollout models and workflow execution. explorer: name: explorer runner_num: 32 + max_timeout: 900 + max_retry_times: 2 + env_vars: {} rollout_model: engine_type: vllm_async engine_num: 1 @@ -321,6 +324,9 @@ explorer: - `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique. - `runner_num`: Number of parallel workflow runners. +- `max_timeout`: Maximum time (in seconds) for a workflow to complete. +- `max_retry_times`: Maximum number of retries for a workflow. +- `env_vars`: Environment variables to be set for every workflow runner. - `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`. - `rollout_model.engine_num`: Number of inference engines. - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism. 
diff --git a/trinity/common/config.py b/trinity/common/config.py index 72e9964857..0360498038 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -305,6 +305,7 @@ class ExplorerConfig: runner_num: int = 1 max_timeout: int = 900 # wait each task for 15 minutes max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout + env_vars: dict = field(default_factory=dict) # for inference models # for rollout model diff --git a/trinity/explorer/runner_pool.py b/trinity/explorer/runner_pool.py index 73c15ab4da..e5ef8bdd5d 100644 --- a/trinity/explorer/runner_pool.py +++ b/trinity/explorer/runner_pool.py @@ -74,6 +74,7 @@ def _create_actors(self, num: int = 1): .options( namespace=self._namespace, scheduling_strategy="SPREAD", + runtime_env={"env_vars": self.config.explorer.env_vars}, ) .remote( self.config, diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 646bfdac4b..61459b7f29 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -73,7 +73,7 @@ def sync_weight(self) -> None: def flush_log(self, step: int) -> None: """Flush the log of the current step.""" - self.engine.logger.log({}, step=step, commit=True) + self.engine.monitor.log({}, step=step, commit=True) def shutdown(self) -> None: # if checkpoint not saved, save the last checkpoint @@ -81,7 +81,7 @@ def shutdown(self) -> None: path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{step_num}") if not os.path.isdir(path) or len(os.listdir(path)) == 0: self.engine.save_checkpoint() - self.engine.logger.close() + self.engine.monitor.close() class TrainEngineWrapper(ABC):