1 | 1 | import unittest |
2 | 2 |
|
| 3 | +import ray |
3 | 4 | import torch |
4 | 5 | from openai import BadRequestError |
5 | 6 | from parameterized import parameterized_class |
|
11 | 12 | get_model_path, |
12 | 13 | get_template_config, |
13 | 14 | ) |
| 15 | +from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME |
14 | 16 | from trinity.common.models import create_inference_models |
15 | 17 | from trinity.common.models.model import ModelWrapper |
16 | 18 | from trinity.common.models.utils import ( |
@@ -310,8 +312,9 @@ async def test_api(self): |
310 | 312 | ) |
311 | 313 | self.assertEqual(2, len(response.choices)) |
312 | 314 | self.assertTrue(response.choices[0].logprobs is not None) |
313 | | - self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs)) |
314 | | - self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0) |
| 315 | + self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs)) |
|     | 316 | +        # check the 3rd token's logprob, because the first two tokens (`<think>`, `\n`) usually have a logprob of zero |
| 317 | + self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0) |
315 | 318 | self.assertTrue(hasattr(response, "prompt_token_ids")) |
316 | 319 | self.assertTrue(len(response.prompt_token_ids) > 0) |
317 | 320 | self.assertTrue(hasattr(response.choices[0], "token_ids")) |
@@ -361,6 +364,89 @@ async def test_api(self): |
361 | 364 | self.assertEqual(len(self.model_wrapper_no_history.history), 0) |
362 | 365 |
|
363 | 366 |
|
| 367 | +class DummySynchronizer: |
| 368 | + def __init__(self): |
| 369 | + pass |
| 370 | + |
| 371 | + def do_nothing(self): |
| 372 | + pass |
| 373 | + |
| 374 | + |
| 375 | +class TestLogprobs(RayUnittestBaseAysnc): |
| 376 | + def setUp(self): |
| 377 | + self.config = get_template_config() |
| 378 | + self.config.mode = "explore" |
| 379 | + self.config.model.model_path = get_model_path() |
| 380 | + self.config.explorer.rollout_model.engine_type = "vllm" |
| 381 | + self.config.explorer.rollout_model.engine_num = 1 |
| 382 | + self.config.explorer.rollout_model.tensor_parallel_size = 1 |
| 383 | + self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE |
| 384 | + self.config.explorer.rollout_model.enable_openai_api = True |
| 385 | + |
| 386 | + self.config.check_and_update() |
| 387 | + self.engines, self.auxiliary_engines = create_inference_models(self.config) |
| 388 | + self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True) |
| 389 | + |
| 390 | + async def test_logprobs(self): |
|     | 391 | +        # initialize the process group so that the weight-sync patches are applied |
| 392 | + sync = ( |
| 393 | + ray.remote(DummySynchronizer) |
| 394 | + .options(name="synchronizer", namespace=self.config.ray_namespace) |
| 395 | + .remote() |
| 396 | + ) |
| 397 | + await sync.__ray_ready__.remote() |
| 398 | + await self.model_wrapper.prepare() |
| 399 | + master_address, master_port = await self.engines[0].get_available_address.remote() |
| 400 | + await self.engines[0].init_process_group.remote( |
| 401 | + master_address, |
| 402 | + master_port, |
| 403 | + world_size=1, |
| 404 | + rank_offset=0, |
| 405 | + group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, |
| 406 | + explorer_name=self.config.explorer.name, |
| 407 | + timeout=20, |
| 408 | + ) |
| 409 | + messages = [ |
| 410 | + {"role": "system", "content": "You are a helpful assistant."}, |
| 411 | + {"role": "user", "content": "What is your name?"}, |
| 412 | + ] |
| 413 | + response_1 = self.model_wrapper.chat(messages, n=1, temperature=1.0, logprobs=True)[0] |
| 414 | + response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0] |
| 415 | + self.assertTrue(response_1.logprobs is not None) |
| 416 | + self.assertTrue(len(response_1.logprobs) > 0) |
| 417 | + self.assertTrue(response_2.logprobs is not None) |
| 418 | + self.assertTrue(len(response_2.logprobs) > 0) |
| 419 | + logprobs_1 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=1.0) |
| 420 | + logprobs_2 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=0.8) |
| 421 | + logprobs_3 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=1.0) |
| 422 | + logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8) |
| 423 | + self.assertEqual(logprobs_1.shape, logprobs_2.shape) |
| 424 | + self.assertEqual(logprobs_3.shape, logprobs_4.shape) |
| 425 | + self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4)) |
|     | 426 | +        self.assertFalse(torch.allclose(logprobs_3, logprobs_4, rtol=0.4)) |
| 427 | + logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1] |
| 428 | + logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1] |
| 429 | + logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1] |
| 430 | + logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1] |
| 431 | + self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape) |
| 432 | + self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4)) |
| 433 | + self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4)) |
| 434 | + self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4)) |
| 435 | + self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4)) |
| 436 | + logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :] |
| 437 | + logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :] |
| 438 | + logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :] |
| 439 | + logprobs_4_response = logprobs_4[response_2.prompt_length - 1 :] |
| 440 | + self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) |
| 441 | + self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape) |
| 442 | + self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) |
| 443 | + self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape) |
| 444 | + self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5)) |
| 445 | + self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5)) |
| 446 | + self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8)) |
| 447 | + self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8)) |
| 448 | + |
| 449 | + |
364 | 450 | class TestAsyncAPIServer(RayUnittestBaseAysnc): |
365 | 451 | def setUp(self): |
366 | 452 | self.config = get_template_config() |
@@ -403,8 +489,9 @@ async def test_api_async(self): |
403 | 489 | ) |
404 | 490 | self.assertEqual(2, len(response.choices)) |
405 | 491 | self.assertTrue(response.choices[0].logprobs is not None) |
406 | | - self.assertEqual(0, len(response.choices[0].logprobs.content[0].top_logprobs)) |
407 | | - self.assertTrue(response.choices[0].logprobs.content[0].logprob < 0) |
| 492 | + self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs)) |
|     | 493 | +        # check the 3rd token's logprob, because the first two tokens (`<think>`, `\n`) usually have a logprob of zero |
| 494 | + self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0) |
408 | 495 | self.assertTrue(hasattr(response, "prompt_token_ids")) |
409 | 496 | self.assertTrue(len(response.prompt_token_ids) > 0) |
410 | 497 | self.assertTrue(hasattr(response.choices[0], "token_ids")) |