@@ -442,7 +442,7 @@ async def test_api(self):
         )
         self.assertEqual(2, len(response.choices))
         self.assertTrue(hasattr(response.choices[0], "token_ids"))
-        self.assertTrue(len(response.choices[0].token_ids) > 0)
+        self.assertTrue(response.choices[0].token_ids is None)
         with self.assertRaises(ValueError):
             self.model_wrapper_no_history.extract_experience_from_history()
         self.assertEqual(len(self.model_wrapper_no_history.history), 0)
@@ -496,6 +496,7 @@ def setUp(self):
         self.config.explorer.rollout_model.tensor_parallel_size = 1
         self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
         self.config.explorer.rollout_model.enable_openai_api = True
+        self.config.explorer.rollout_model.enable_log_requests = True

         self.config.check_and_update()
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
@@ -540,17 +541,17 @@ async def test_logprobs_api(self):
         logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8)
         self.assertEqual(logprobs_1.shape, logprobs_2.shape)
         self.assertEqual(logprobs_3.shape, logprobs_4.shape)
-        self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4))
-        self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4))
+        self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.3, atol=1e-3))
+        self.assertFalse(torch.allclose(logprobs_3, logprobs_4, rtol=0.3, atol=1e-3))
         logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
         logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
         logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
         logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
         self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
-        self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
-        self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
-        self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
-        self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
+        self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.3, atol=1e-3))
+        self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3))
+        self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.3, atol=1e-3))
+        self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3))
         logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
         logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
         logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
@@ -559,10 +560,18 @@ async def test_logprobs_api(self):
         self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
         self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
         self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
-        self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
-        self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
-        self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
-        self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))
+        self.assertTrue(
+            torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3)
+        )
+        self.assertFalse(
+            torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3)
+        )
+        self.assertTrue(
+            torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3)
+        )
+        self.assertFalse(
+            torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3)
+        )

         # test vllm engine logprobs with different temperature
         response_1 = self.model_wrapper.chat(
@@ -581,17 +590,17 @@ async def test_logprobs_api(self):
         logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8)
         self.assertEqual(logprobs_1.shape, logprobs_2.shape)
         self.assertEqual(logprobs_3.shape, logprobs_4.shape)
-        self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4))
-        self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4))
+        self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.3, atol=1e-3))
+        self.assertFalse(torch.allclose(logprobs_3, logprobs_4, rtol=0.3, atol=1e-3))
         logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1]
         logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1]
         logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1]
         logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1]
         self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape)
-        self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4))
-        self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.4))
-        self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4))
-        self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4))
+        self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.3, atol=1e-3))
+        self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3))
+        self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.3, atol=1e-3))
+        self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.3, atol=1e-3))
         logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :]
         logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :]
         logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :]
@@ -600,10 +609,18 @@ async def test_logprobs_api(self):
         self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape)
         self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape)
         self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape)
-        self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5))
-        self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5))
-        self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8))
-        self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8))
+        self.assertTrue(
+            torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.3, atol=1e-3)
+        )
+        self.assertFalse(
+            torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3)
+        )
+        self.assertTrue(
+            torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3)
+        )
+        self.assertFalse(
+            torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3)
+        )

         # test openai api and vllm engine logprobs consistency
         await self.model_wrapper.clean_workflow_state()
@@ -747,7 +764,7 @@ async def test_api_async(self):
         )
         self.assertEqual(2, len(response.choices))
         self.assertTrue(hasattr(response.choices[0], "token_ids"))
-        self.assertTrue(len(response.choices[0].token_ids) > 0)
+        self.assertTrue(response.choices[0].token_ids is None)
         with self.assertRaises(ValueError):
             self.model_wrapper_no_history.extract_experience_from_history()
         self.assertEqual(len(self.model_wrapper_no_history.history), 0)
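
Note on the tolerance changes above: torch.allclose(a, b, rtol=..., atol=...) passes when |a - b| <= atol + rtol * |b| holds elementwise, and the default atol is only 1e-8. A minimal sketch of why adding atol=1e-3 matters for logprob tensors follows; the tensor values are illustrative assumptions, not taken from the actual test run:

import torch

# Logprobs of high-confidence tokens sit very close to zero, so a purely
# relative tolerance turns tiny sampling noise there into a "mismatch".
a = torch.tensor([-0.0004, -2.10])
b = torch.tensor([-0.0007, -2.05])

# rtol only: for the near-zero entry, 0.0003 > 0.3 * 0.0007, so this fails
# even though both values are numerically negligible.
print(torch.allclose(a, b, rtol=0.3))             # False
# rtol plus atol=1e-3: 0.0003 <= 1e-3 + 0.3 * 0.0007, so the noise is
# absorbed while the relative term still governs the larger entry.
print(torch.allclose(a, b, rtol=0.3, atol=1e-3))  # True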