Skip to content

Commit 1b561fe

Browse files
committed
update vllm test
1 parent 0478f5f commit 1b561fe

File tree

3 files changed

+11
-125
lines changed

3 files changed

+11
-125
lines changed

.github/workflows/docker/docker-compose.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ services:
99
environment:
1010
- HF_ENDPOINT=https://hf-mirror.com
1111
- RAY_ADDRESS=auto
12+
- CHECKPOINT_ROOT_DIR=/mnt/checkpoints
13+
- DATA_ROOT_DIR=/mnt/data
1214
working_dir: /workspace
1315
networks:
1416
- ray-network
1517
volumes:
16-
- trinity-volume:/data
18+
- trinity-volume:/mnt
1719
- ../../..:/workspace
1820
ports:
1921
- "6379:6379"
@@ -35,9 +37,11 @@ services:
3537
ray start --address=trinity-node-1:6379 --block
3638
environment:
3739
- HF_ENDPOINT=https://hf-mirror.com
40+
- CHECKPOINT_ROOT_DIR=/mnt/checkpoints
41+
- DATA_ROOT_DIR=/mnt/data
3842
working_dir: /workspace
3943
volumes:
40-
- trinity-volume:/data
44+
- trinity-volume:/mnt
4145
- ../../..:/workspace
4246
depends_on:
4347
- trinity-node-1

tests/common/tmp/template_config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ explorer:
2828
enable_prefix_caching: false
2929
enforce_eager: true
3030
dtype: bfloat16
31-
temperature: 0.0
32-
top_p: 1.0
31+
temperature: 0.2
32+
top_p: 0.95
3333
top_k: -1
3434
seed: 42
3535
logprobs: 0
Lines changed: 3 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
tokenize_and_mask_messages_hf,
1414
)
1515

16-
config_dir = os.path.join(os.path.dirname(__file__), "test_data", "template.yaml")
16+
config_dir = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml")
1717

1818

1919
CHAT_TEMPLATE = r"""
@@ -76,129 +76,11 @@
7676
"""
7777

7878

79-
@unittest.skip("Skip VLLM test")
80-
class TestSyncvLLMModel(unittest.TestCase):
81-
def setUp(self):
82-
ray.init(ignore_reinit_error=True)
83-
self.config = load_config(config_dir)
84-
self.engines = create_rollout_models(self.config)
85-
self.assertEqual(len(self.engines), self.config.explorer.engine_num)
86-
87-
def test_generate(self):
88-
prompts = [
89-
"Hello, world!",
90-
"Hello, my name is",
91-
]
92-
cluster_results = []
93-
for engine in self.engines:
94-
cluster_results.extend(ray.get(engine.generate.remote(prompts)))
95-
96-
self.assertEqual(
97-
len(cluster_results),
98-
len(self.engines) * len(prompts) * self.config.explorer.repeat_times,
99-
)
100-
for i in range(len(self.engines)):
101-
for j in range(len(prompts)):
102-
for k in range(self.config.explorer.repeat_times):
103-
self.assertEqual(
104-
cluster_results[
105-
i * len(prompts) * self.config.explorer.repeat_times
106-
+ j * self.config.explorer.repeat_times
107-
+ k
108-
].prompt_text,
109-
prompts[j],
110-
)
111-
112-
def test_chat_and_logprobs(self):
113-
messages = [
114-
{"role": "system", "content": "You are a helpful assistant."},
115-
{"role": "user", "content": "Hello, world!"},
116-
]
117-
cluster_results = []
118-
for engine in self.engines:
119-
cluster_results.extend(ray.get(engine.chat.remote(messages)))
120-
self.assertEqual(
121-
len(cluster_results), len(self.engines) * self.config.explorer.repeat_times
122-
)
123-
for i in range(len(self.engines)):
124-
for k in range(self.config.explorer.repeat_times):
125-
self.assertIn(
126-
"Hello, world!",
127-
cluster_results[i * self.config.explorer.repeat_times + k].prompt_text,
128-
)
129-
self.assertIn(
130-
"You are a helpful assistant.",
131-
cluster_results[i * self.config.explorer.repeat_times + k].prompt_text,
132-
)
133-
logprobs = ray.get(self.engines[0].logprobs.remote(cluster_results[0].tokens))
134-
self.assertEqual(logprobs.shape[0], cluster_results[0][0].tokens.shape[0] - 1)
135-
136-
137-
@unittest.skip("Skip VLLM test")
138-
class TestAsyncvLLMModel(unittest.TestCase):
139-
def setUp(self):
140-
ray.init(ignore_reinit_error=True)
141-
self.config = load_config(config_dir)
142-
self.config.explorer.engine_type = "vllm_async"
143-
self.engines = create_rollout_models(self.config)
144-
self.assertEqual(len(self.engines), self.config.explorer.engine_num)
145-
146-
def test_generate(self):
147-
prompts = ["Hello, world!", "Hi, my name is", "How are you?", "What's up?"]
148-
cluster_results = []
149-
refs = []
150-
for engine in self.engines:
151-
for prompt in prompts:
152-
refs.append(engine.generate_async.remote(prompt))
153-
cluster_results = ray.get(refs)
154-
155-
self.assertEqual(
156-
len(cluster_results),
157-
len(self.engines) * self.config.explorer.repeat_times * len(prompts),
158-
)
159-
160-
def test_chat_and_logprobs(self):
161-
messages = [
162-
[
163-
{"role": "system", "content": "You are a helpful assistant."},
164-
{"role": "user", "content": "Hello, world!"},
165-
],
166-
[
167-
{"role": "system", "content": "You are a helpful assistant."},
168-
{"role": "user", "content": "Please tell me about yourself."},
169-
],
170-
[
171-
{"role": "system", "content": "You are a helpful assistant."},
172-
{"role": "user", "content": "Please tell me a joke."},
173-
],
174-
]
175-
cluster_results = []
176-
refs = []
177-
for engine in self.engines:
178-
for message in messages:
179-
refs.append(engine.chat_async.remote(message))
180-
cluster_results = ray.get(refs)
181-
182-
self.assertEqual(
183-
len(cluster_results),
184-
len(self.engines) * self.config.explorer.repeat_times * len(messages),
185-
)
186-
logprobs_refs = []
187-
for i, messages in enumerate(messages):
188-
token_ids = cluster_results[i][0].tokens.tolist()
189-
logprobs_refs.append(self.engines[0].logprobs_async.remote(token_ids))
190-
logprobs = ray.get(logprobs_refs)
191-
for i, messages in enumerate(messages):
192-
self.assertEqual(logprobs[i].shape[0], cluster_results[i][0].tokens.shape[0])
193-
194-
19579
class TestModelWrapper:
19680
def test_generate(self):
19781
prompts = ["Hello, world!", "Hello, my name is"]
19882
results = self.model_wrapper.generate(prompts)
19983
self.assertEqual(len(results), len(prompts) * self.config.explorer.repeat_times)
200-
201-
def test_chat_and_logprobs(self):
20284
messages = [
20385
{"role": "system", "content": "You are a helpful assistant."},
20486
{"role": "user", "content": "What's the weather like today?"},
@@ -235,23 +117,23 @@ def test_chat_and_logprobs(self):
235117
self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
236118

237119

238-
# @unittest.skip("Skip VLLM test")
239120
class TestModelWrapperSync(TestModelWrapper, unittest.TestCase):
240121
def setUp(self):
241122
ray.init(ignore_reinit_error=True)
242123
self.config = load_config(config_dir)
124+
self.config.model.model_path = os.environ.get("MODEL_PATH")
243125
self.config.explorer.engine_type = "vllm"
244126
self.config.explorer.engine_num = 1
245127
self.config.explorer.chat_template = CHAT_TEMPLATE
246128
self.engines = create_rollout_models(self.config)
247129
self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm")
248130

249131

250-
# @unittest.skip("Skip VLLM test")
251132
class TestModelWrapperAsync(TestModelWrapper, unittest.TestCase):
252133
def setUp(self):
253134
ray.init(ignore_reinit_error=True)
254135
self.config = load_config(config_dir)
136+
self.config.model.model_path = os.environ.get("MODEL_PATH")
255137
self.config.explorer.engine_type = "vllm_async"
256138
self.config.explorer.engine_num = 1
257139
self.config.explorer.chat_template = CHAT_TEMPLATE

0 commit comments

Comments
 (0)