Skip to content

Commit 0a6f326

Browse files
authored
Bug fix in alfworld (#107)
1 parent 8fa6f76 commit 0a6f326

File tree

4 files changed

+10
-2
lines changed

4 files changed

+10
-2
lines changed

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@ Controls the rollout models and workflow execution.
310310
explorer:
311311
name: explorer
312312
runner_num: 32
313+
max_timeout: 900
314+
max_retry_times: 2
315+
env_vars: {}
313316
rollout_model:
314317
engine_type: vllm_async
315318
engine_num: 1
@@ -321,6 +324,9 @@ explorer:
321324

322325
- `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
323326
- `runner_num`: Number of parallel workflow runners.
327+
- `max_timeout`: Maximum time (in seconds) for a workflow to complete.
328+
- `max_retry_times`: Maximum number of retries for a workflow.
329+
- `env_vars`: Environment variables to be set for every workflow runners.
324330
- `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`.
325331
- `rollout_model.engine_num`: Number of inference engines.
326332
- `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.

trinity/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ class ExplorerConfig:
305305
runner_num: int = 1
306306
max_timeout: int = 900 # wait each task for 15 minutes
307307
max_retry_times: int = 2 # retry each task for 2 times if it fails or timeout
308+
env_vars: dict = field(default_factory=dict)
308309

309310
# for inference models
310311
# for rollout model

trinity/explorer/runner_pool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def _create_actors(self, num: int = 1):
7474
.options(
7575
namespace=self._namespace,
7676
scheduling_strategy="SPREAD",
77+
runtime_env={"env_vars": self.config.explorer.env_vars},
7778
)
7879
.remote(
7980
self.config,

trinity/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,15 @@ def sync_weight(self) -> None:
7373

7474
def flush_log(self, step: int) -> None:
7575
"""Flush the log of the current step."""
76-
self.engine.logger.log({}, step=step, commit=True)
76+
self.engine.monitor.log({}, step=step, commit=True)
7777

7878
def shutdown(self) -> None:
7979
# if checkpoint not saved, save the last checkpoint
8080
step_num = self.engine.train_step_num
8181
path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{step_num}")
8282
if not os.path.isdir(path) or len(os.listdir(path)) == 0:
8383
self.engine.save_checkpoint()
84-
self.engine.logger.close()
84+
self.engine.monitor.close()
8585

8686

8787
class TrainEngineWrapper(ABC):

0 commit comments

Comments
 (0)