Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def metric_list(self, metric_prefix: str) -> List[str]:
class RayUnittestBase(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init(ignore_reinit_error=True)
ray.init(ignore_reinit_error=True, namespace="trinity_unittest")

@classmethod
def tearDownClass(cls):
Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def setUp(self):
self.config.model.model_path = get_model_path()
self.config.explorer.rollout_model.engine_type = "vllm_async"
self.config.algorithm.repeat_times = 3
self.config.explorer.rollout_model.use_v1 = False
self.config.project = "Trainer-unittest"
self.config.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.monitor.monitor_type = "tensorboard"
Expand All @@ -45,6 +44,7 @@ class TestTrainerCountdown(BaseTrainerCase):
def test_trainer(self):
"""Test the both and bench mode."""
# test both mode
self.config.explorer.rollout_model.use_v1 = False
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
self.config.buffer.explorer_input.eval_tasksets.append(
get_unittest_dataset_config("countdown", "test")
Expand Down
3 changes: 0 additions & 3 deletions trinity/buffer/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from copy import deepcopy
from typing import List

import ray

from trinity.buffer.writer.file_writer import JSONWriter
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import BufferConfig, StorageConfig
Expand All @@ -20,7 +18,6 @@ def is_json_file(path: str) -> bool:
return path.endswith(".json") or path.endswith(".jsonl")


@ray.remote
class QueueActor:
"""An asyncio.Queue based queue actor."""

Expand Down
2 changes: 2 additions & 0 deletions trinity/buffer/ray_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
ray.remote(cls)
.options(
name=f"sql-{storage_config.name}",
namespace=ray.get_runtime_context().namespace,
get_if_exists=True,
)
.remote(storage_config, config)
Expand Down Expand Up @@ -154,6 +155,7 @@ def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
ray.remote(cls)
.options(
name=f"json-{storage_config.name}",
namespace=ray.get_runtime_context().namespace,
get_if_exists=True,
)
.remote(storage_config, config)
Expand Down
13 changes: 9 additions & 4 deletions trinity/buffer/reader/queue_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@ class QueueReader(BufferReader):
def __init__(self, storage_config: StorageConfig, config: BufferConfig):
assert storage_config.storage_type == StorageType.QUEUE
self.read_batch_size = config.read_batch_size
self.queue = QueueActor.options(
name=f"queue-{storage_config.name}",
get_if_exists=True,
).remote(storage_config, config)
self.queue = (
ray.remote(QueueActor)
.options(
name=f"queue-{storage_config.name}",
namespace=ray.get_runtime_context().namespace,
get_if_exists=True,
)
.remote(storage_config, config)
)

def read(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
Expand Down
13 changes: 9 additions & 4 deletions trinity/buffer/writer/queue_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@ class QueueWriter(BufferWriter):
def __init__(self, meta: StorageConfig, config: BufferConfig):
assert meta.storage_type == StorageType.QUEUE
self.config = config
self.queue = QueueActor.options(
name=f"queue-{meta.name}",
get_if_exists=True,
).remote(meta, config)
self.queue = (
ray.remote(QueueActor)
.options(
name=f"queue-{meta.name}",
namespace=ray.get_runtime_context().namespace,
get_if_exists=True,
)
.remote(meta, config)
)

def write(self, data: List) -> None:
ray.get(self.queue.put_batch.remote(data))
Expand Down
53 changes: 44 additions & 9 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@

def bench(config: Config) -> None:
"""Evaluate model."""
explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
explorer = (
ray.remote(Explorer)
.options(
name=EXPLORER_NAME,
namespace=ray.get_runtime_context().namespace,
)
.remote(config)
)
try:
ray.get(explorer.prepare.remote())
ray.get(explorer.benchmark.remote())
Expand All @@ -34,7 +41,14 @@ def bench(config: Config) -> None:
def explore(config: Config) -> None:
"""Run explorer."""
try:
explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
explorer = (
ray.remote(Explorer)
.options(
name=EXPLORER_NAME,
namespace=ray.get_runtime_context().namespace,
)
.remote(config)
)
ray.get(explorer.prepare.remote())
ray.get(explorer.sync_weight.remote())
ray.get(explorer.explore.remote())
Expand All @@ -47,7 +61,14 @@ def explore(config: Config) -> None:
def train(config: Config) -> None:
"""Run trainer."""
try:
trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
trainer = (
ray.remote(Trainer)
.options(
name=TRAINER_NAME,
namespace=ray.get_runtime_context().namespace,
)
.remote(config)
)
ray.get(trainer.prepare.remote())
ray.get(trainer.sync_weight.remote())
ray.get(trainer.train.remote())
Expand All @@ -67,8 +88,23 @@ def both(config: Config) -> None:
the latest step. The specific number of experiences may vary for different
algorithms and tasks.
"""
explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
namespace = ray.get_runtime_context().namespace
explorer = (
ray.remote(Explorer)
.options(
name=EXPLORER_NAME,
namespace=namespace,
)
.remote(config)
)
trainer = (
ray.remote(Trainer)
.options(
name=TRAINER_NAME,
namespace=namespace,
)
.remote(config)
)
ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()])
ray.get(
[
Expand Down Expand Up @@ -191,17 +227,16 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
activate_data_module(
f"{data_processor_config.data_processor_url}/experience_pipeline", config_path
)
ray_namespace = config.ray_namespace
if dlc:
from trinity.utils.dlc_utils import setup_ray_cluster

setup_ray_cluster(namespace=ray_namespace)
setup_ray_cluster(namespace=config.ray_namespace)
else:
from trinity.utils.dlc_utils import is_running

if not is_running:
raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
ray.init(namespace=ray_namespace, ignore_reinit_error=True)
ray.init(namespace=config.ray_namespace, ignore_reinit_error=True)
if config.mode == "explore":
explore(config)
elif config.mode == "train":
Expand All @@ -214,7 +249,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
if dlc:
from trinity.utils.dlc_utils import stop_ray_cluster

stop_ray_cluster()
stop_ray_cluster(namespace=config.ray_namespace)


def studio(port: int = 8501):
Expand Down
9 changes: 3 additions & 6 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ class InferenceModelConfig:

# ! DO NOT SET
bundle_indices: str = ""
ray_namespace: str = ""


@dataclass
Expand Down Expand Up @@ -354,7 +353,7 @@ class Config:
checkpoint_root_dir: str = ""
# ! DO NOT SET, automatically generated as `checkpoint_root_dir/project/name`
checkpoint_job_dir: str = ""
# ! DO NOT SET, automatically generated as f"{config.project}-{config.name}"
# If not set, automatically generated as f"{config.project}-{config.name}"
ray_namespace: str = ""

algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
Expand Down Expand Up @@ -579,7 +578,8 @@ def check_and_update(self) -> None: # noqa: C901
self._check_deprecated()

# set namespace
self.ray_namespace = f"{self.project}-{self.name}"
if self.ray_namespace is None or len(self.ray_namespace) == 0:
self.ray_namespace = f"{self.project}-{self.name}"

# check algorithm
self._check_algorithm()
Expand Down Expand Up @@ -611,9 +611,6 @@ def check_and_update(self) -> None: # noqa: C901
self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
if self.explorer.rollout_model.max_response_tokens is None:
self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
self.explorer.rollout_model.ray_namespace = self.ray_namespace
for model in self.explorer.auxiliary_models:
model.ray_namespace = self.ray_namespace

# check synchronizer
self.synchronizer.explorer_world_size = (
Expand Down
4 changes: 3 additions & 1 deletion trinity/common/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def create_inference_models(
for bundle_id, node_id in bundle_node_map.items():
node_bundle_map[node_id].append(bundle_id)
allocator = _BundleAllocator(node_bundle_map)

namespace = ray.get_runtime_context().namespace
# create rollout models
for _ in range(config.explorer.rollout_model.engine_num):
bundles_for_engine = allocator.allocate(config.explorer.rollout_model.tensor_parallel_size)
Expand All @@ -101,6 +101,7 @@ def create_inference_models(
.options(
num_cpus=0,
num_gpus=0 if config.explorer.rollout_model.tensor_parallel_size > 1 else 1,
namespace=namespace,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_capture_child_tasks=True,
Expand Down Expand Up @@ -128,6 +129,7 @@ def create_inference_models(
.options(
num_cpus=0,
num_gpus=0 if model_config.tensor_parallel_size > 1 else 1,
namespace=namespace,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_capture_child_tasks=True,
Expand Down
3 changes: 2 additions & 1 deletion trinity/common/models/vllm_async_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import aiohttp
import ray
import torch
import vllm
from vllm.sampling_params import RequestOutputKind
Expand Down Expand Up @@ -298,7 +299,7 @@ async def init_process_group(
timeout,
update_with_checkpoint,
state_dict_meta,
self.config.ray_namespace,
ray.get_runtime_context().namespace,
),
)

Expand Down
3 changes: 2 additions & 1 deletion trinity/common/models/vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import threading
from typing import List, Optional, Tuple

import ray
import torch
import vllm
from vllm import LLM
Expand Down Expand Up @@ -112,7 +113,7 @@ def init_process_group(
timeout,
update_with_checkpoint,
state_dict_meta,
self.config.ray_namespace,
ray.get_runtime_context().namespace,
),
)

Expand Down
6 changes: 3 additions & 3 deletions trinity/common/models/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def init_process_group(
timeout: int = 1200,
update_with_checkpoint: bool = True,
state_dict_meta: list = None,
namespace: str = "",
namespace: str = None,
):
"""Init torch process group for model weights update"""
assert torch.distributed.is_initialized(), "default torch process group must be initialized"
Expand Down Expand Up @@ -53,7 +53,7 @@ def init_process_group(
group_name=group_name,
)
logger.info("vLLM init_process_group finished.")
self.namespace = namespace
self._namespace = namespace
self._explorer_actor = None

def set_state_dict_meta(self, state_dict_meta):
Expand All @@ -63,7 +63,7 @@ def update_weight(self):
"""Broadcast weight to all vllm workers from source rank 0 (actor model)"""
assert self._state_dict_meta is not None
if self._explorer_actor is None:
self._explorer_actor = ray.get_actor(name=EXPLORER_NAME, namespace=self.namespace)
self._explorer_actor = ray.get_actor(name=EXPLORER_NAME, namespace=self._namespace)
for name, dtype_str, shape in self._state_dict_meta:
if self._weight_update_rank == 0:
weight = ray.get(self._explorer_actor.get_weight.remote(name))
Expand Down
21 changes: 18 additions & 3 deletions trinity/explorer/runner_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
]
self._idle_actors = list()
self.actor_to_engine_index = {}
self._namespace = ray.get_runtime_context().namespace
self._create_actors(config.explorer.runner_num)

def _create_actors(self, num: int = 1):
Expand All @@ -68,8 +69,17 @@ def _create_actors(self, num: int = 1):
self.auxiliary_models, self.auxiliary_engine_status_list
)
]
new_actor = WorkflowRunner.remote(
self.config, self.models[engine_index], selected_auxiliary_models
new_actor = (
ray.remote(WorkflowRunner)
.options(
namespace=self._namespace,
scheduling_strategy="SPREAD",
)
.remote(
self.config,
self.models[engine_index],
selected_auxiliary_models,
)
)
new_actors.append(new_actor)
self.engine_status[engine_index] += 1
Expand Down Expand Up @@ -219,7 +229,12 @@ def get_next(self) -> Status:
ray.kill(a)
# TODO: balance the model
self._return_actor(
WorkflowRunner.remote(
ray.remote(WorkflowRunner)
.options(
namespace=self._namespace,
scheduling_strategy="SPREAD",
)
.remote(
self.config,
self.models[
random.randint(0, self.config.explorer.rollout_model.engine_num - 1)
Expand Down
3 changes: 0 additions & 3 deletions trinity/explorer/workflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from dataclasses import dataclass
from typing import List, Optional

import ray

from trinity.buffer import get_buffer_writer
from trinity.common.config import Config
from trinity.common.experience import Experience
Expand All @@ -26,7 +24,6 @@ class Status:
message: Optional[str] = None


@ray.remote(scheduling_strategy="SPREAD")
class WorkflowRunner:
"""A Ray remote actor to run the workflow and put the returned experiences into the buffer."""

Expand Down
Loading