|
13 | 13 | tokenize_and_mask_messages_hf, |
14 | 14 | ) |
15 | 15 |
|
16 | | -config_dir = os.path.join(os.path.dirname(__file__), "test_data", "template.yaml") |
| 16 | +config_dir = os.path.join(os.path.dirname(__file__), "tmp", "template_config.yaml") |
17 | 17 |
|
18 | 18 |
|
19 | 19 | CHAT_TEMPLATE = r""" |
|
76 | 76 | """ |
77 | 77 |
|
78 | 78 |
|
79 | | -@unittest.skip("Skip VLLM test") |
80 | | -class TestSyncvLLMModel(unittest.TestCase): |
81 | | - def setUp(self): |
82 | | - ray.init(ignore_reinit_error=True) |
83 | | - self.config = load_config(config_dir) |
84 | | - self.engines = create_rollout_models(self.config) |
85 | | - self.assertEqual(len(self.engines), self.config.explorer.engine_num) |
86 | | - |
87 | | - def test_generate(self): |
88 | | - prompts = [ |
89 | | - "Hello, world!", |
90 | | - "Hello, my name is", |
91 | | - ] |
92 | | - cluster_results = [] |
93 | | - for engine in self.engines: |
94 | | - cluster_results.extend(ray.get(engine.generate.remote(prompts))) |
95 | | - |
96 | | - self.assertEqual( |
97 | | - len(cluster_results), |
98 | | - len(self.engines) * len(prompts) * self.config.explorer.repeat_times, |
99 | | - ) |
100 | | - for i in range(len(self.engines)): |
101 | | - for j in range(len(prompts)): |
102 | | - for k in range(self.config.explorer.repeat_times): |
103 | | - self.assertEqual( |
104 | | - cluster_results[ |
105 | | - i * len(prompts) * self.config.explorer.repeat_times |
106 | | - + j * self.config.explorer.repeat_times |
107 | | - + k |
108 | | - ].prompt_text, |
109 | | - prompts[j], |
110 | | - ) |
111 | | - |
112 | | - def test_chat_and_logprobs(self): |
113 | | - messages = [ |
114 | | - {"role": "system", "content": "You are a helpful assistant."}, |
115 | | - {"role": "user", "content": "Hello, world!"}, |
116 | | - ] |
117 | | - cluster_results = [] |
118 | | - for engine in self.engines: |
119 | | - cluster_results.extend(ray.get(engine.chat.remote(messages))) |
120 | | - self.assertEqual( |
121 | | - len(cluster_results), len(self.engines) * self.config.explorer.repeat_times |
122 | | - ) |
123 | | - for i in range(len(self.engines)): |
124 | | - for k in range(self.config.explorer.repeat_times): |
125 | | - self.assertIn( |
126 | | - "Hello, world!", |
127 | | - cluster_results[i * self.config.explorer.repeat_times + k].prompt_text, |
128 | | - ) |
129 | | - self.assertIn( |
130 | | - "You are a helpful assistant.", |
131 | | - cluster_results[i * self.config.explorer.repeat_times + k].prompt_text, |
132 | | - ) |
133 | | - logprobs = ray.get(self.engines[0].logprobs.remote(cluster_results[0].tokens)) |
134 | | - self.assertEqual(logprobs.shape[0], cluster_results[0][0].tokens.shape[0] - 1) |
135 | | - |
136 | | - |
137 | | -@unittest.skip("Skip VLLM test") |
138 | | -class TestAsyncvLLMModel(unittest.TestCase): |
139 | | - def setUp(self): |
140 | | - ray.init(ignore_reinit_error=True) |
141 | | - self.config = load_config(config_dir) |
142 | | - self.config.explorer.engine_type = "vllm_async" |
143 | | - self.engines = create_rollout_models(self.config) |
144 | | - self.assertEqual(len(self.engines), self.config.explorer.engine_num) |
145 | | - |
146 | | - def test_generate(self): |
147 | | - prompts = ["Hello, world!", "Hi, my name is", "How are you?", "What's up?"] |
148 | | - cluster_results = [] |
149 | | - refs = [] |
150 | | - for engine in self.engines: |
151 | | - for prompt in prompts: |
152 | | - refs.append(engine.generate_async.remote(prompt)) |
153 | | - cluster_results = ray.get(refs) |
154 | | - |
155 | | - self.assertEqual( |
156 | | - len(cluster_results), |
157 | | - len(self.engines) * self.config.explorer.repeat_times * len(prompts), |
158 | | - ) |
159 | | - |
160 | | - def test_chat_and_logprobs(self): |
161 | | - messages = [ |
162 | | - [ |
163 | | - {"role": "system", "content": "You are a helpful assistant."}, |
164 | | - {"role": "user", "content": "Hello, world!"}, |
165 | | - ], |
166 | | - [ |
167 | | - {"role": "system", "content": "You are a helpful assistant."}, |
168 | | - {"role": "user", "content": "Please tell me about yourself."}, |
169 | | - ], |
170 | | - [ |
171 | | - {"role": "system", "content": "You are a helpful assistant."}, |
172 | | - {"role": "user", "content": "Please tell me a joke."}, |
173 | | - ], |
174 | | - ] |
175 | | - cluster_results = [] |
176 | | - refs = [] |
177 | | - for engine in self.engines: |
178 | | - for message in messages: |
179 | | - refs.append(engine.chat_async.remote(message)) |
180 | | - cluster_results = ray.get(refs) |
181 | | - |
182 | | - self.assertEqual( |
183 | | - len(cluster_results), |
184 | | - len(self.engines) * self.config.explorer.repeat_times * len(messages), |
185 | | - ) |
186 | | - logprobs_refs = [] |
187 | | - for i, messages in enumerate(messages): |
188 | | - token_ids = cluster_results[i][0].tokens.tolist() |
189 | | - logprobs_refs.append(self.engines[0].logprobs_async.remote(token_ids)) |
190 | | - logprobs = ray.get(logprobs_refs) |
191 | | - for i, messages in enumerate(messages): |
192 | | - self.assertEqual(logprobs[i].shape[0], cluster_results[i][0].tokens.shape[0]) |
193 | | - |
194 | | - |
195 | 79 | class TestModelWrapper: |
196 | 80 | def test_generate(self): |
197 | 81 | prompts = ["Hello, world!", "Hello, my name is"] |
198 | 82 | results = self.model_wrapper.generate(prompts) |
199 | 83 | self.assertEqual(len(results), len(prompts) * self.config.explorer.repeat_times) |
200 | | - |
201 | | - def test_chat_and_logprobs(self): |
202 | 84 | messages = [ |
203 | 85 | {"role": "system", "content": "You are a helpful assistant."}, |
204 | 86 | {"role": "user", "content": "What's the weather like today?"}, |
@@ -235,23 +117,23 @@ def test_chat_and_logprobs(self): |
235 | 117 | self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens)) |
236 | 118 |
|
237 | 119 |
|
238 | | -# @unittest.skip("Skip VLLM test") |
239 | 120 | class TestModelWrapperSync(TestModelWrapper, unittest.TestCase): |
240 | 121 | def setUp(self): |
241 | 122 | ray.init(ignore_reinit_error=True) |
242 | 123 | self.config = load_config(config_dir) |
| 124 | + self.config.model.model_path = os.environ.get("MODEL_PATH") |
243 | 125 | self.config.explorer.engine_type = "vllm" |
244 | 126 | self.config.explorer.engine_num = 1 |
245 | 127 | self.config.explorer.chat_template = CHAT_TEMPLATE |
246 | 128 | self.engines = create_rollout_models(self.config) |
247 | 129 | self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm") |
248 | 130 |
|
249 | 131 |
|
250 | | -# @unittest.skip("Skip VLLM test") |
251 | 132 | class TestModelWrapperAsync(TestModelWrapper, unittest.TestCase): |
252 | 133 | def setUp(self): |
253 | 134 | ray.init(ignore_reinit_error=True) |
254 | 135 | self.config = load_config(config_dir) |
| 136 | + self.config.model.model_path = os.environ.get("MODEL_PATH") |
255 | 137 | self.config.explorer.engine_type = "vllm_async" |
256 | 138 | self.config.explorer.engine_num = 1 |
257 | 139 | self.config.explorer.chat_template = CHAT_TEMPLATE |
|
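These setups now read the checkpoint location from a MODEL_PATH environment variable (self.config.model.model_path = os.environ.get("MODEL_PATH")), so the variable must be exported before the suite starts. A minimal sketch of driving these tests follows; the test-file name and checkpoint path are hypothetical placeholders, not paths taken from this diff:

    # Minimal sketch: export MODEL_PATH, then run the vLLM wrapper tests.
    # "tests/test_vllm_models.py" and the checkpoint path below are
    # assumptions; substitute the real paths for your checkout.
    import os
    import subprocess

    os.environ["MODEL_PATH"] = "/models/my-local-checkpoint"
    subprocess.run(
        ["python", "-m", "pytest", "tests/test_vllm_models.py", "-v"],
        check=True,  # raise CalledProcessError if any test fails
    )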