Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,9 @@ Controls the rollout models and workflow execution.
explorer:
name: explorer
runner_num: 32
max_timeout: 900
max_retry_times: 2
env_vars: {}
rollout_model:
engine_type: vllm_async
engine_num: 1
Expand All @@ -321,6 +324,9 @@ explorer:

- `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
- `runner_num`: Number of parallel workflow runners.
- `max_timeout`: Maximum time (in seconds) for a workflow to complete.
- `max_retry_times`: Maximum number of retries for a workflow.
- `env_vars`: Environment variables to be set for every workflow runners.
- `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`.
- `rollout_model.engine_num`: Number of inference engines.
- `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.
Expand Down
1 change: 1 addition & 0 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ class ExplorerConfig:
runner_num: int = 1
max_timeout: int = 900 # wait each task for 15 minutes
max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout
env_vars: dict = field(default_factory=dict)

# for inference models
# for rollout model
Expand Down
1 change: 1 addition & 0 deletions trinity/explorer/runner_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def _create_actors(self, num: int = 1):
.options(
namespace=self._namespace,
scheduling_strategy="SPREAD",
runtime_env={"env_vars": self.config.explorer.env_vars},
)
.remote(
self.config,
Expand Down
4 changes: 2 additions & 2 deletions trinity/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ def sync_weight(self) -> None:

def flush_log(self, step: int) -> None:
"""Flush the log of the current step."""
self.engine.logger.log({}, step=step, commit=True)
self.engine.monitor.log({}, step=step, commit=True)

def shutdown(self) -> None:
# if checkpoint not saved, save the last checkpoint
step_num = self.engine.train_step_num
path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{step_num}")
if not os.path.isdir(path) or len(os.listdir(path)) == 0:
self.engine.save_checkpoint()
self.engine.logger.close()
self.engine.monitor.close()


class TrainEngineWrapper(ABC):
Expand Down