2 changes: 1 addition & 1 deletion .github/workflows/unittest.yaml
@@ -19,7 +19,7 @@ jobs:
- uses: actions/checkout@v4
with:
path: trinity-${{ github.run_id }}
fetch-depth: 0
ref: refs/pull/${{ github.event.issue.number }}/head

- name: Setup docker compose
working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -21,7 +21,7 @@ requires-python = ">=3.10"
dependencies = [
"verl==0.3.0.post1",
"ray[default]==2.43.0",
"vllm>=0.8.3",
"vllm>=0.8.5",
"tensordict==0.6.2",
"wandb",
"omegaconf",
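Reviewer note (not part of the diff): the floor on vllm moves from 0.8.3 to 0.8.5 here, while the in-code version assert is dropped from trinity/common/models/__init__.py below. If a runtime guard were still wanted, a minimal sketch would compare parsed versions rather than raw strings, since string comparison misorders versions like "0.10.0":

# Hypothetical runtime guard; not code from this PR.
from packaging.version import Version

import vllm

# Compare parsed versions: the string comparison "0.10.0" >= "0.8.5" is False.
assert Version(vllm.__version__) >= Version("0.8.5"), "Trinity-RFT requires vllm >= 0.8.5"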
10 changes: 3 additions & 7 deletions tests/buffer/queue_test.py
@@ -1,25 +1,21 @@
import unittest

import ray
import torch

from tests.tools import RayUnittestBase
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.writer.queue_writer import QueueWriter
from trinity.common.config import BufferConfig, DatasetConfig
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience


class TestQueueBuffer(unittest.TestCase):
def setUp(self):
ray.init(ignore_reinit_error=True)

class TestQueueBuffer(RayUnittestBase):
def test_queue_buffer(self):
total_num = 8
put_batch_size = 2
read_batch_size = 4
meta = DatasetConfig(
name="test_buffer",
namespace="test_namespace",
algorithm_type=AlgorithmType.PPO,
storage_type=StorageType.QUEUE,
)
4 changes: 3 additions & 1 deletion tests/common/config_test.py
@@ -26,7 +26,9 @@ def test_all_examples_are_valid(self):
example_dir = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
for example_name in os.listdir(example_dir):
for filename in os.listdir(os.path.join(example_dir, example_name)):
if filename.endswith(".yaml") and not filename.startswith("train"):
if filename.endswith(".yaml") and not (
filename.startswith("train_") or filename.startswith("verl_")
):
print(f"Checking config: {filename}")
config_path = os.path.join(example_dir, example_name, filename)
try:
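Reviewer note (not part of the diff): the tightened predicate now skips only the train_* and verl_* prefixes instead of anything starting with "train". A quick illustration of the new condition, with hypothetical filenames:

def should_check(filename: str) -> bool:
    # Mirrors the new condition above.
    return filename.endswith(".yaml") and not (
        filename.startswith("train_") or filename.startswith("verl_")
    )

assert should_check("countdown.yaml")
assert should_check("trainable.yaml")  # wrongly skipped by the old startswith("train")
assert not should_check("train_gsm8k.yaml")
assert not should_check("verl_config.yaml")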
59 changes: 44 additions & 15 deletions tests/common/vllm_test.py
@@ -129,47 +129,76 @@ def test_generate(self):
self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))


class TestModelWrapperSync(BaseTestModelWrapper, RayUnittestBase):
class TestModelWrapperSyncV0(BaseTestModelWrapper, RayUnittestBase):
def setUp(self):
ray.init(ignore_reinit_error=True)
self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm"
self.config.explorer.engine_num = 1
self.config.explorer.tensor_parallel_size = 1
self.config.explorer.engine_num = 2
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")


class TestModelWrapperAsync(BaseTestModelWrapper, RayUnittestBase):
@classmethod
def setUpClass(cls):
class TestModelWrapperAsyncV0(BaseTestModelWrapper, RayUnittestBase):
def setUp(self):
ray.init(ignore_reinit_error=True)
self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.engine_num = 2
self.config.explorer.tensor_parallel_size = 1
self.config.explorer.use_v1 = False
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")

@classmethod
def tearDownClass(cls):
ray.shutdown()

class TestModelWrapperAsyncTPV0(BaseTestModelWrapper, RayUnittestBase):
def setUp(self):
ray.init(ignore_reinit_error=True)
self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.engine_num = 1
self.config.explorer.engine_num = 2
self.config.explorer.tensor_parallel_size = 2
self.config.explorer.use_v1 = False
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


class TestTokenizer(unittest.TestCase):
@classmethod
def setUpClass(cls):
class TestModelWrapperAsyncTPV1(BaseTestModelWrapper, RayUnittestBase):
def setUp(self):
ray.init(ignore_reinit_error=True)
self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.engine_num = 2
self.config.explorer.tensor_parallel_size = 2
self.config.explorer.use_v1 = True
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")


class TestModelWrapperAsyncV1(BaseTestModelWrapper, RayUnittestBase):
def setUp(self):
ray.init(ignore_reinit_error=True)
self.config = get_template_config()
self.config.model.model_path = get_model_path()
self.config.explorer.engine_type = "vllm_async"
self.config.explorer.engine_num = 2
self.config.explorer.tensor_parallel_size = 1
self.config.explorer.use_v1 = True
self.config.explorer.chat_template = CHAT_TEMPLATE
self.engines = create_rollout_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm_async")

@classmethod
def tearDownClass(cls):
ray.shutdown()

class TestTokenizer(unittest.TestCase):
def test_assistant_token_mask(self):
messages = [
{"role": "system", "content": "You are a helpful assistant."},
2 changes: 2 additions & 0 deletions tests/explorer/explorer_test.py
@@ -37,6 +37,7 @@ class TestExplorerCountdownEval(BaseExplorerCase):
def test_explorer(self):
self.config.data = get_unittest_dataset_config("countdown")
self.config.monitor.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.explorer.use_v1 = True
self.config.check_and_update()
explore(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
@@ -53,6 +54,7 @@ def test_explorer(self):
self.config.data = get_unittest_dataset_config("countdown")
self.config.monitor.name = f"explore-no-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.data.eval_split = None
self.config.explorer.use_v1 = False
self.config.check_and_update()
explore(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
1 change: 1 addition & 0 deletions tests/explorer/runner_pool_test.py
@@ -70,6 +70,7 @@ def setUp(self):
self.config.buffer.pad_token_id = 0
self.config.buffer.train_dataset = DatasetConfig(
name="test",
namespace="test_runner_pool",
storage_type=StorageType.QUEUE,
algorithm_type=AlgorithmType.PPO,
)
1 change: 1 addition & 0 deletions tests/template/config.yaml
@@ -32,6 +32,7 @@ explorer:
logprobs: 0
backend: nccl
use_ray: false
use_v1: true
trainer:
trainer_type: verl
trainer_config_path: tests/template/verl_config.yaml
2 changes: 1 addition & 1 deletion tests/tools.py
@@ -102,4 +102,4 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
ray.shutdown()
ray.shutdown(_exiting_interpreter=True)
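Reviewer note (not part of the diff): only the tearDownClass body of RayUnittestBase is visible in this hunk; a minimal sketch of the full helper, with the setUpClass body assumed, would be:

import unittest

import ray


class RayUnittestBase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Assumed: start (or reuse) a local Ray runtime once per test class.
        ray.init(ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls):
        # _exiting_interpreter=True takes Ray's atexit shutdown path;
        # the flag is Ray-internal (note the leading underscore).
        ray.shutdown(_exiting_interpreter=True)

Switching tests from per-method ray.init/ray.shutdown to this shared base, as queue_test.py does above, avoids repeated Ray startup between test methods.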
2 changes: 2 additions & 0 deletions trinity/common/config.py
@@ -173,6 +173,8 @@ class ExplorerConfig:
use_ray: bool = False
gpu_memory_utilization: float = 0.9
enable_chunked_prefil: bool = False
use_v1: bool = True
bundle_indices: str = "" # DO NOT SET this field

# for workflow runner
max_pending_requests: int = 5
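Reviewer note (not part of the diff): bundle_indices is populated by create_rollout_models (next file) as a comma-separated string of placement-group bundle ids for one engine, hence the "DO NOT SET" warning. A hypothetical helper on the consuming side might parse it as:

def parse_bundle_indices(raw: str) -> list[int]:
    # "" -> [], "3,4" -> [3, 4]; illustrative only, not code from this PR.
    return [int(part) for part in raw.split(",") if part.strip()]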
59 changes: 33 additions & 26 deletions trinity/common/models/__init__.py
@@ -1,3 +1,4 @@
from collections import defaultdict
from typing import List

from trinity.common.config import Config
@@ -12,20 +13,15 @@ def create_rollout_models(
Each model has `tensor_parallel_size` workers.
"""
import ray
import vllm
from ray.util.placement_group import placement_group
from ray.util.placement_group import placement_group, placement_group_table
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from trinity.common.models.vllm_async_model import vLLMAysncRolloutModel
from trinity.common.models.vllm_model import vLLMRolloutModel
from trinity.utils.log import get_logger

logger = get_logger(__name__)

assert vllm.__version__ >= "0.7.3", "Trinity-RFT only supports vllm >= 0.7.3"

engine_num = config.explorer.engine_num
tensor_parallel_size = config.explorer.tensor_parallel_size
is_multi_process = config.explorer.tensor_parallel_size > 1

vllm_engines = []

@@ -36,28 +32,39 @@
else:
raise ValueError(f"Unknown engine type: {config.explorer.engine_type}")

bundles = [{"GPU": tensor_parallel_size, "CPU": 1} for _ in range(engine_num)]
pg = placement_group(bundles)
bundles = [{"GPU": 1} for _ in range(engine_num * tensor_parallel_size)]
pg = placement_group(bundles, strategy="PACK")
ray.get(pg.ready())

for i in range(engine_num):
logger.info(f"Creating vLLM engine {i}")
scheduling_strategy = None
vllm_engines = []

scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=i,
)
# to address https://github.com/ray-project/ray/issues/51117
# aggregate bundles belonging to the same node
bundle_node_map = placement_group_table(pg)["bundles_to_node_id"]
node_bundle_map = defaultdict(list)
for bundle_id, node_id in bundle_node_map.items():
node_bundle_map[node_id].append(bundle_id)

vllm_engines.append(
engine_cls.options( # type: ignore [attr-defined]
num_cpus=0,
num_gpus=tensor_parallel_size,
scheduling_strategy=scheduling_strategy,
).remote(
config=config,
)
for node_id, bundle_ids in node_bundle_map.items():
assert len(bundle_ids) % tensor_parallel_size == 0, (
f"Node {node_id} has {len(bundle_ids)} bundles, "
f"which is not divisible by tensor_parallel_size({tensor_parallel_size})"
)

for i in range(len(bundle_ids) // tensor_parallel_size):
bundles_for_engine = bundle_ids[
i * tensor_parallel_size : (i + 1) * tensor_parallel_size
]
config.explorer.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine])
vllm_engines.append(
engine_cls.options(
num_cpus=0,
num_gpus=0 if is_multi_process else 1,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=bundles_for_engine[0],
),
).remote(
config=config,
)
)
return vllm_engines
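Reviewer note (not part of the diff): the new scheduling logic creates one single-GPU bundle per vLLM worker, groups bundles by the node Ray placed them on (working around ray-project/ray#51117), and pins each engine to tensor_parallel_size co-located bundles via bundle_indices. A self-contained sketch of the grouping step, using only the Ray placement-group API already imported above (the helper name is illustrative):

from collections import defaultdict
from typing import Dict, List

from ray.util.placement_group import placement_group_table


def group_bundles_by_node(pg) -> Dict[str, List[int]]:
    # Map each node id to the placement-group bundle indices it hosts.
    bundle_node_map = placement_group_table(pg)["bundles_to_node_id"]
    node_bundle_map: Dict[str, List[int]] = defaultdict(list)
    for bundle_id, node_id in bundle_node_map.items():
        node_bundle_map[node_id].append(bundle_id)
    return dict(node_bundle_map)

Slicing a node's list into chunks of tensor_parallel_size guarantees every worker of one engine lands on the same node; the engine actor itself is scheduled on the slice's first bundle with num_gpus=0 in the multi-process case, presumably so its worker processes can claim the GPUs instead.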
19 changes: 0 additions & 19 deletions trinity/common/models/model.py
@@ -45,29 +45,10 @@ async def convert_messages_to_experience_async(self, messages: List[dict]) -> Ex
"""Convert a list of messages into an experience in async."""
raise NotImplementedError

@abstractmethod
def sync_model(self, update_weight_args_list: List) -> bool:
"""Sync model weights."""
# TODO: sync with high efficiency

@abstractmethod
def get_ckp_version(self) -> int:
"""Get the checkpoint version."""

@abstractmethod
def init_process_group(
self,
master_address: str,
master_port: int,
rank_offset: int,
world_size: int,
group_name: str,
backend: str = "nccl",
timeout: int = 1200,
update_with_checkpoint: bool = True,
) -> None:
"""Init the process group for model weights sync."""

def get_address(self) -> Tuple[str, int]:
"""Get the address of the actor."""
address = ray.util.get_node_ip_address()