@@ -1,6 +1,7 @@
 import math
 import uuid
 from typing import Any, List, Optional
+import asyncio
 
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
@@ -19,6 +20,7 @@ def __init__(
         temperature: float = 0.6,
         top_p: float = 1.0,
         topk: int = 5,
+        timeout: float = 300.0,
         **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, **kwargs)
@@ -42,6 +44,8 @@ def __init__(
         self.temperature = temperature
         self.top_p = top_p
         self.topk = topk
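+        # Per-request generation timeout in seconds, enforced via asyncio.wait_for below.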
+        self.timeout = timeout
 
     @staticmethod
     def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
@@ -57,6 +61,13 @@ def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
         lines.append(prompt)
         return "\n".join(lines)
 
+    async def _consume_generator(self, generator):
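+        # Drain the stream: the async engine yields cumulative RequestOutput snapshots, so the last one is the final result.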
+        final_output = None
+        async for request_output in generator:
+            final_output = request_output
+        return final_output
+
     async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
@@ -71,14 +82,28 @@ async def generate_answer(
 
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
 
-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if not final_output or not final_output.outputs:
-            return ""
-
-        return final_output.outputs[0].text
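+        # Bound the drain with a timeout so a stuck request cannot hang the caller indefinitely.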
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout,
+            )
+
+            if not final_output or not final_output.outputs:
+                return ""
+
+            return final_output.outputs[0].text
+
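+        # In every failure mode, abort the in-flight request so the engine stops generating and frees its resources.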
+        except asyncio.TimeoutError:
+            await self.engine.abort(request_id)
+            raise
+        except asyncio.CancelledError:
+            await self.engine.abort(request_id)
+            raise
+        except Exception:
+            await self.engine.abort(request_id)
+            raise
 
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
@@ -95,41 +121,56 @@ async def generate_topk_per_token(
 
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
 
-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if (
-            not final_output
-            or not final_output.outputs
-            or not final_output.outputs[0].logprobs
-        ):
-            return []
-
-        top_logprobs = final_output.outputs[0].logprobs[0]
-
-        candidate_tokens = []
-        for _, logprob_obj in top_logprobs.items():
-            tok_str = (
-                logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
-            )
-            prob = float(math.exp(logprob_obj.logprob))
-            candidate_tokens.append(Token(tok_str, prob))
-
-        candidate_tokens.sort(key=lambda x: -x.prob)
-
-        if candidate_tokens:
-            main_token = Token(
-                text=candidate_tokens[0].text,
-                prob=candidate_tokens[0].prob,
-                top_candidates=candidate_tokens,
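+        # Same timeout/abort guard as generate_answer above.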
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout,
             )
-            return [main_token]
-        return []
+
+            if (
+                not final_output
+                or not final_output.outputs
+                or not final_output.outputs[0].logprobs
+            ):
+                return []
+
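+            # logprobs[0] maps token_id -> Logprob for the first generated position.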
+            top_logprobs = final_output.outputs[0].logprobs[0]
+
+            candidate_tokens = []
+            for _, logprob_obj in top_logprobs.items():
+                tok_str = (
+                    logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+                )
+                prob = float(math.exp(logprob_obj.logprob))
+                candidate_tokens.append(Token(tok_str, prob))
+
+            candidate_tokens.sort(key=lambda x: -x.prob)
+
+            if candidate_tokens:
+                main_token = Token(
+                    text=candidate_tokens[0].text,
+                    prob=candidate_tokens[0].prob,
+                    top_candidates=candidate_tokens,
+                )
+                return [main_token]
+            return []
+
+        except asyncio.TimeoutError:
+            await self.engine.abort(request_id)
+            raise
+        except asyncio.CancelledError:
+            await self.engine.abort(request_id)
+            raise
+        except Exception:
+            await self.engine.abort(request_id)
+            raise
 
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
         raise NotImplementedError(
             "VLLMWrapper does not support per-token logprobs yet."
         )
+