Skip to content

Commit c10dd93

Browse files
committed
* update ray init method
* set ray_namespace for StorageConfigs to the global ray_namespace if they are not set
1 parent 55800f6 commit c10dd93

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

trinity/common/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,8 @@ def _check_buffer(self) -> None: # noqa: C901
452452
self.buffer.explorer_input.taskset.format.reply_prefix = (
453453
self.buffer.explorer_input.reply_prefix
454454
)
455+
if self.buffer.explorer_input.taskset.ray_namespace is None:
456+
self.buffer.explorer_input.taskset.ray_namespace = self.ray_namespace
455457

456458
remained_tasksets = []
457459
for idx, dataset in enumerate(self.buffer.explorer_input.eval_tasksets):
@@ -469,6 +471,8 @@ def _check_buffer(self) -> None: # noqa: C901
469471
dataset.format.system_prompt = self.buffer.explorer_input.system_prompt
470472
if dataset.format.reply_prefix is None:
471473
dataset.format.reply_prefix = self.buffer.explorer_input.reply_prefix
474+
if dataset.ray_namespace is None:
475+
dataset.ray_namespace = self.ray_namespace
472476
remained_tasksets.append(dataset)
473477
self.buffer.explorer_input.eval_tasksets = remained_tasksets
474478

@@ -493,12 +497,16 @@ def _check_buffer(self) -> None: # noqa: C901
493497
self.buffer.trainer_input.experience_buffer.algorithm_type = (
494498
self.algorithm.algorithm_type
495499
)
500+
if self.buffer.trainer_input.experience_buffer.ray_namespace is None:
501+
self.buffer.trainer_input.experience_buffer.ray_namespace = self.ray_namespace
496502

497503
# set buffer.explorer_output
498504
if self.buffer.explorer_output is None:
499505
self.buffer.explorer_output = self.buffer.trainer_input.experience_buffer
500506
else:
501507
self.buffer.explorer_output.algorithm_type = self.algorithm.algorithm_type
508+
if self.buffer.explorer_output.ray_namespace is None:
509+
self.buffer.explorer_output.ray_namespace = self.ray_namespace
502510

503511
# check trainer_input.sft_warmup_dataset
504512
if (
@@ -510,6 +518,8 @@ def _check_buffer(self) -> None: # noqa: C901
510518
)
511519
if self.buffer.trainer_input.sft_warmup_dataset is not None:
512520
self.buffer.trainer_input.sft_warmup_dataset.algorithm_type = "sft" # TODO
521+
if self.buffer.trainer_input.sft_warmup_dataset.ray_namespace is None:
522+
self.buffer.trainer_input.sft_warmup_dataset.ray_namespace = self.ray_namespace
513523

514524
# check input/output buffers in experience pipelines
515525
if self.data_processor.experience_pipeline is not None:

trinity/data/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def data_processor(pipeline_type):
2424
config.check_and_update()
2525

2626
# init ray
27-
ray.init(namespace=config.ray_namespace)
27+
ray.init(namespace=config.ray_namespace, ignore_reinit_error=True)
2828

2929
pipeline_config = getattr(config.data_processor, pipeline_type)
3030
if pipeline_config is None:

0 commit comments

Comments
 (0)