@@ -327,7 +327,28 @@ async def test_api(self):
327327 )
328328 exps = self .model_wrapper .extract_experience_from_history ()
329329 self .assertEqual (len (exps ), 4 )
330+ for exp in exps :
331+ self .assertTrue (len (exp .tokens ) > 0 )
332+ self .assertTrue (len (exp .logprobs ) > 0 )
333+ self .assertTrue (exp .prompt_length + len (exp .logprobs ) == len (exp .tokens ))
330334 self .assertEqual (len (self .model_wrapper .extract_experience_from_history ()), 0 )
335+ response = openai_client .chat .completions .create (
336+ model = model_id ,
337+ messages = messages ,
338+ )
339+ exps = self .model_wrapper .extract_experience_from_history ()
340+ self .assertEqual (len (exps ), 1 )
341+ self .assertTrue (len (exps [0 ].tokens ) > 0 )
342+ self .assertTrue (len (exps [0 ].logprobs ) > 0 )
343+ self .assertTrue (exps [0 ].prompt_length + len (exps [0 ].logprobs ) == len (exps [0 ].tokens ))
344+ response = openai_client .chat .completions .create (
345+ model = model_id ,
346+ messages = messages ,
347+ logprobs = False ,
348+ )
349+ exps = self .model_wrapper .extract_experience_from_history ()
350+ self .assertEqual (len (exps ), 1 )
351+ self .assertTrue (len (exps [0 ].logprobs ) == 0 )
331352 response = self .model_wrapper_no_history .get_openai_client ().chat .completions .create (
332353 model = model_id , messages = messages , n = 2
333354 )
@@ -400,7 +421,28 @@ async def test_api_async(self):
400421 )
401422 exps = self .model_wrapper .extract_experience_from_history ()
402423 self .assertEqual (len (exps ), 4 )
424+ for exp in exps :
425+ self .assertTrue (len (exp .tokens ) > 0 )
426+ self .assertTrue (len (exp .logprobs ) > 0 )
427+ self .assertTrue (exp .prompt_length + len (exp .logprobs ) == len (exp .tokens ))
403428 self .assertEqual (len (self .model_wrapper .extract_experience_from_history ()), 0 )
429+ response = await openai_client .chat .completions .create (
430+ model = model_id ,
431+ messages = messages ,
432+ )
433+ exps = self .model_wrapper .extract_experience_from_history ()
434+ self .assertEqual (len (exps ), 1 )
435+ self .assertTrue (len (exps [0 ].tokens ) > 0 )
436+ self .assertTrue (len (exps [0 ].logprobs ) > 0 )
437+ self .assertTrue (exps [0 ].prompt_length + len (exps [0 ].logprobs ) == len (exps [0 ].tokens ))
438+ response = await openai_client .chat .completions .create (
439+ model = model_id ,
440+ messages = messages ,
441+ logprobs = False ,
442+ )
443+ exps = self .model_wrapper .extract_experience_from_history ()
444+ self .assertEqual (len (exps ), 1 )
445+ self .assertTrue (len (exps [0 ].logprobs ) == 0 )
404446 response = (
405447 await self .model_wrapper_no_history .get_openai_async_client ().chat .completions .create (
406448 model = model_id , messages = messages , n = 2
0 commit comments