+import asyncio
import os
import unittest

-import ray
import torch
from openai import BadRequestError
from parameterized import parameterized_class
    get_model_path,
    get_template_config,
)
-from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper
from trinity.common.models.utils import (
@@ -29,6 +28,16 @@ def print_debug(*args):
    print(*args)


+async def prepare_engines(engines, auxiliary_engines):
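+    """Concurrently call prepare() on all rollout engines and auxiliary engines.
+
+    Launches prepare.remote() on every engine and awaits the resulting Ray
+    object refs together so engine preparation overlaps.
+    """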
+    prepare_model_refs = []
+    for engine in engines:
+        prepare_model_refs.append(engine.prepare.remote())
+    for engine_group in auxiliary_engines:
+        for engine in engine_group:
+            prepare_model_refs.append(engine.prepare.remote())
+    await asyncio.gather(*prepare_model_refs)
+
+
# Qwen2.5 chat template with {% generation %} mark
CHAT_TEMPLATE = r"""
{%- if tools %}
@@ -127,6 +136,7 @@ def setUp(self):
    async def test_generate(
        self,
    ):
+        await prepare_engines(self.engines, self.auxiliary_engines)
        await self.model_wrapper.prepare()
        self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path)
        self.assertEqual(await self.model_wrapper.model_path_async, self.config.model.model_path)
@@ -244,6 +254,7 @@ def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path)

    async def test_model_len(self):
+        await prepare_engines(self.engines, self.auxiliary_engines)
        await self.model_wrapper.prepare()
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
@@ -311,6 +322,7 @@ def setUp(self):
        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

    async def test_model_len(self):
+        await prepare_engines(self.engines, self.auxiliary_engines)
        await self.model_wrapper.prepare()
        messages = [
            {"role": "user", "content": "How are you?"},
@@ -362,6 +374,7 @@ def setUp(self):
        )

    async def test_api(self):
+        await prepare_engines(self.engines, self.auxiliary_engines)
        await self.model_wrapper.prepare()
        await self.model_wrapper_no_history.prepare()
        openai_client = self.model_wrapper.get_openai_client()
@@ -435,12 +448,42 @@ async def test_api(self):
        self.assertEqual(len(self.model_wrapper_no_history.history), 0)


-class DummySynchronizer:
-    def __init__(self):
-        pass
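+# FrozenLake-style system and user prompts used by the logprob tests below;
+# the long system prompt gives a non-trivial prompt segment to slice and compare.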
+SYSTEM_PROMPT = """
+You are Qwen, created by Alibaba Cloud. You are a helpful assistant. You are walking on a frozen lake.
+
+FrozenLake Quick Guide
+Goal: Reach the goal (G). Player (P) and Goal (G) must overlap.
+
+Symbols:
+_ Frozen | O Hole | G Goal | P Player
+
+Rules:
+1. Avoid falling into holes (O).
+2. Frozen tiles are slippery, you may move perpendicular to your intended direction.
+
+Valid Action (separated by | ):
+Up | Down | Left | Right
+
+Rewards:
+Fall into hole: 0
+Reach goal: +1.0
+
+You will be provided the current observation, please decide on the next Action.
+You should show your thought process and then input the final action in ``` ```.
+You should only output the NEXT ACTION at each interaction in the ``` ```. For example, if you want to move up, you should output ```Up```.
+You should plan ahead and try to achieve it in the minimum number of steps.
+You should be aware that frozen tiles can be slippery, but the chance is small and you should not overthink it.
+
+Please show your thinking process and put the final action in ``` ```. In every turn, the final action MUST be one of Up, Down, Left, Right.
+"""

-    def do_nothing(self):
-        pass
+USER_PROMPT = """Current Observation (0):
+_ G _
+_ _ _
+P O O
+You have not achieved the goal, P has not reached G yet. Please give the next action.
+The maximum number of steps remaining is 10.
+"""


class TestLogprobs(RayUnittestBaseAysnc):
@@ -458,31 +501,76 @@ def setUp(self):
        self.engines, self.auxiliary_engines = create_inference_models(self.config)
        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

-    async def test_logprobs(self):
-        # use init process group to apply patches
-        sync = (
-            ray.remote(DummySynchronizer)
-            .options(name="synchronizer", namespace=self.config.ray_namespace)
-            .remote()
-        )
-        await sync.__ray_ready__.remote()
+    async def test_logprobs_api(self):
+        await prepare_engines(self.engines, self.auxiliary_engines)
        await self.model_wrapper.prepare()
-        master_address, master_port = await self.engines[0].get_available_address.remote()
-        await self.engines[0].init_process_group.remote(
-            master_address,
-            master_port,
-            world_size=1,
-            rank_offset=0,
-            group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
-            explorer_name=self.config.explorer.name,
-            timeout=20,
-        )
        messages = [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": "What is your name?"},
+            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "user", "content": USER_PROMPT},
        ]
-        response_1 = self.model_wrapper.chat(messages, n=1, temperature=1.0, logprobs=True)[0]
-        response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0]
+
+        # Test OpenAI API logprobs with different temperatures
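+        # The requests below go through the OpenAI-compatible endpoint; the
+        # wrapper records each call, and extract_experience_from_history()
+        # returns the recorded experience (tokens, prompt_length, logprobs).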
513+
514+ self .model_client = self .model_wrapper .get_openai_async_client ()
515+ _ = await self .model_client .chat .completions .create (
516+ model = self .model_client .model_path ,
517+ messages = messages ,
518+ n = 1 ,
519+ temperature = 1.0 ,
520+ logprobs = True ,
521+ max_tokens = 15 ,
522+ )
523+ response_1 = self .model_wrapper .extract_experience_from_history ()[0 ]
524+ _ = await self .model_client .chat .completions .create (
525+ model = self .model_client .model_path ,
526+ messages = messages ,
527+ n = 1 ,
528+ temperature = 0.8 ,
529+ logprobs = True ,
530+ max_tokens = 15 ,
531+ )
532+ response_2 = self .model_wrapper .extract_experience_from_history ()[0 ]
533+ self .assertTrue (response_1 .logprobs is not None )
534+ self .assertTrue (len (response_1 .logprobs ) > 0 )
535+ self .assertTrue (response_2 .logprobs is not None )
536+ self .assertTrue (len (response_2 .logprobs ) > 0 )
537+ logprobs_1 = self .model_wrapper .logprobs (response_1 .tokens .tolist (), temperature = 1.0 )
538+ logprobs_2 = self .model_wrapper .logprobs (response_1 .tokens .tolist (), temperature = 0.8 )
539+ logprobs_3 = self .model_wrapper .logprobs (response_2 .tokens .tolist (), temperature = 1.0 )
540+ logprobs_4 = self .model_wrapper .logprobs (response_2 .tokens .tolist (), temperature = 0.8 )
541+ self .assertEqual (logprobs_1 .shape , logprobs_2 .shape )
542+ self .assertEqual (logprobs_3 .shape , logprobs_4 .shape )
543+ self .assertFalse (torch .allclose (logprobs_1 , logprobs_2 , rtol = 0.4 ))
544+ self .assertFalse (torch .allclose (logprobs_3 , logprobs_4 , atol = 0.4 ))
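+        # Prompt positions (first prompt_length - 1 entries): with the same
+        # scoring temperature the prompt logprobs should agree across the two
+        # responses, since they share the same prompt; they should differ when
+        # the scoring temperature differs.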
+        logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
+        logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
+        logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
+        logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
+        self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
+        self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
+        self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
+        self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
+        self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
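+        # Response positions: recomputed logprobs should match the logprobs
+        # recorded at generation time only when the same temperature is used.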
+        logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
+        logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
+        logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
+        logprobs_4_response = logprobs_4[response_2.prompt_length - 1 :]
+        self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
+        self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
+        self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
+        self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
+        self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
+        self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
+        self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
+        self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))
+
+        # Test vLLM engine logprobs with different temperatures
+        response_1 = self.model_wrapper.chat(
+            messages, n=1, temperature=1.0, logprobs=True, max_tokens=15
+        )[0]
+        response_2 = self.model_wrapper.chat(
+            messages, n=1, temperature=0.8, logprobs=True, max_tokens=15
+        )[0]
        self.assertTrue(response_1.logprobs is not None)
        self.assertTrue(len(response_1.logprobs) > 0)
        self.assertTrue(response_2.logprobs is not None)
@@ -517,6 +605,56 @@ async def test_logprobs(self):
        self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
        self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))

+        # Test OpenAI API and vLLM engine logprobs consistency
+        await self.model_wrapper.clean_workflow_state()
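+        # clean_workflow_state() is assumed to clear the wrapper's recorded
+        # history, so extract_experience_from_history() below only sees the
+        # new single-token requests.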
+        _ = await self.model_client.chat.completions.create(
+            model=self.model_client.model_path,
+            messages=messages,
+            n=1,
+            temperature=1.0,
+            logprobs=0,
+            max_tokens=1,
+        )
+        response_openai_1 = self.model_wrapper.extract_experience_from_history()[0]
+        _ = await self.model_client.chat.completions.create(
+            model=self.model_client.model_path,
+            messages=messages,
+            n=1,
+            temperature=0.8,
+            logprobs=0,
+            max_tokens=1,
+        )
+        response_openai_2 = self.model_wrapper.extract_experience_from_history()[0]
+        response_vllm_1 = self.model_wrapper.chat(
+            messages,
+            n=1,
+            temperature=1.0,
+            logprobs=0,
+            max_tokens=1,
+        )[0]
+        response_vllm_2 = self.model_wrapper.chat(
+            messages,
+            n=1,
+            temperature=0.8,
+            logprobs=0,
+            max_tokens=1,
+        )[0]
+        self.assertEqual(len(response_openai_1.tokens), len(response_vllm_1.tokens))
+        self.assertTrue(
+            torch.allclose(
+                response_openai_1.logprobs,
+                response_vllm_1.logprobs,
+                rtol=0.1,
+            )
+        )
+        self.assertTrue(
+            torch.allclose(
+                response_openai_2.logprobs,
+                response_vllm_2.logprobs,
+                rtol=0.1,
+            )
+        )
+

class TestAsyncAPIServer(RayUnittestBaseAysnc):
    def setUp(self):
@@ -537,6 +675,7 @@ def setUp(self):
        )

    async def test_api_async(self):
+        await prepare_engines(self.engines, self.auxiliary_engines)
        await self.model_wrapper.prepare()
        await self.model_wrapper_no_history.prepare()
        openai_client = self.model_wrapper.get_openai_async_client()
@@ -758,6 +897,7 @@ async def test_api_tool_calls(self):
        import json
        import time

+        await prepare_engines(self.engines, self.auxiliary_engines)
        await self.model_wrapper.prepare()
        await self.model_wrapper_no_history.prepare()
        tokenizer = AutoTokenizer.from_pretrained(get_api_model_path())