     InferenceClient,
     InferenceEndpoint,
     InferenceEndpointTimeoutError,
+    TextGenerationOutput,
     create_inference_endpoint,
     get_inference_endpoint,
 )
-from huggingface_hub.inference._text_generation import TextGenerationResponse
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -148,7 +148,7 @@ def max_length(self):

     def __async_process_request(
         self, context: str, stop_tokens: list[str], max_tokens: int
-    ) -> Coroutine[None, list[TextGenerationResponse], str]:
+    ) -> Coroutine[None, list[TextGenerationOutput], str]:
         # Todo: add an option to launch with conversational instead for chat prompts
         # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational
         generated_text = self.async_client.text_generation(
@@ -162,7 +162,7 @@ def __async_process_request(

         return generated_text

-    def __process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationResponse:
+    def __process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationOutput:
         # Todo: add an option to launch with conversational instead for chat prompts
         # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational
         generated_text = self.client.text_generation(
@@ -179,7 +179,7 @@ def __process_request(self, context: str, stop_tokens: list[str], max_tokens: in
     async def __async_process_batch_generate(
         self,
         requests: list[GreedyUntilRequest | GreedyUntilWithLogitsRequest],
-    ) -> list[TextGenerationResponse]:
+    ) -> list[TextGenerationOutput]:
         return await asyncio.gather(
             *[
                 self.__async_process_request(
@@ -194,7 +194,7 @@ async def __async_process_batch_generate(
     def __process_batch_generate(
         self,
         requests: list[GreedyUntilRequest | GreedyUntilWithLogitsRequest],
-    ) -> list[TextGenerationResponse]:
+    ) -> list[TextGenerationOutput]:
         return [
             self.__process_request(
                 context=request.context,
@@ -206,7 +206,7 @@ def __process_batch_generate(

     async def __async_process_batch_logprob(
         self, requests: list[LoglikelihoodRequest], rolling: bool = False
-    ) -> list[TextGenerationResponse]:
+    ) -> list[TextGenerationOutput]:
         return await asyncio.gather(
             *[
                 self.__async_process_request(
@@ -220,7 +220,7 @@ async def __async_process_batch_logprob(

     def __process_batch_logprob(
         self, requests: list[LoglikelihoodRequest], rolling: bool = False
-    ) -> list[TextGenerationResponse]:
+    ) -> list[TextGenerationOutput]:
         return [
             self.__process_request(
                 context=request.context if rolling else request.context + request.choice,
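For context, `TextGenerationOutput` is the type that `huggingface_hub` exposes at the top level (as imported in the hunk above) and that `InferenceClient.text_generation` returns when called with `details=True`; the async client yields the same type from a coroutine. A minimal sketch of that call pattern, assuming a reachable text-generation endpoint (the URL, prompt, and generation parameters below are illustrative placeholders, not values from this diff):

from huggingface_hub import InferenceClient, TextGenerationOutput

# Hypothetical endpoint URL, for illustration only.
client = InferenceClient(model="http://localhost:8080")

# With details=True, text_generation returns a TextGenerationOutput
# (generated_text plus token-level details) instead of a plain string.
out: TextGenerationOutput = client.text_generation(
    "The quick brown fox",
    max_new_tokens=16,
    stop_sequences=["\n"],
    details=True,
)
print(out.generated_text)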