11import math
22from dataclasses import dataclass
33from typing import List , Dict , Optional
4- from openai import AsyncOpenAI , RateLimitError , APIConnectionError , APITimeoutError , ChatCompletion
5- from models import TopkTokenModel , Token
4+ import openai
5+ from openai import AsyncOpenAI , RateLimitError , APIConnectionError , APITimeoutError
66from tenacity import (
77 retry ,
88 stop_after_attempt ,
99 wait_exponential ,
1010 retry_if_exception_type ,
1111)
1212
13- def get_top_response_tokens (response : ChatCompletion ) -> List [Token ]:
13+ from models import TopkTokenModel , Token
14+
15+
16+ def get_top_response_tokens (response : openai .ChatCompletion ) -> List [Token ]:
1417 token_logprobs = response .choices [0 ].logprobs .content
1518 tokens = []
1619 for token_prob in token_logprobs :
@@ -76,6 +79,7 @@ async def generate_topk_per_token(self, text: str, history: Optional[List[str]]
7679
7780 completion = await self .client .chat .completions .create (
7881 model = self .model_name ,
82+ messages = kwargs ["messages" ],
7983 ** kwargs
8084 )
8185
@@ -94,7 +98,11 @@ async def generate_answer(self, text: str, history: Optional[List[str]] = None,
9498
9599 completion = await self .client .chat .completions .create (
96100 model = self .model_name ,
101+ messages = kwargs ["messages" ],
97102 ** kwargs
98103 )
99104
100105 return completion .choices [0 ].message .content
106+
async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
    """Placeholder for an operation this client does not implement.

    The method accepts the call and always raises; it never inspects its
    arguments.

    NOTE(review): judging by the name and return annotation, this is
    presumably meant to return one ``Token`` per input token with its
    probability -- confirm against the base model interface before
    implementing.

    Args:
        text: The input text (unused).
        history: Optional prior conversation turns (unused).

    Raises:
        NotImplementedError: Always.
    """
    # Name the concrete class and the method so the failure is self-explanatory
    # in logs; the exception type is unchanged for callers that catch it.
    raise NotImplementedError(
        f"{type(self).__name__} does not implement generate_inputs_prob"
    )
0 commit comments