Commit 59af943

fix: fix timeout error in vllmwrapper
1 parent: f2796f5


graphgen/models/llm/local/vllm_wrapper.py

Lines changed: 73 additions & 38 deletions
@@ -1,6 +1,7 @@
 import math
 import uuid
 from typing import Any, List, Optional
+import asyncio
 
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
@@ -19,6 +20,7 @@ def __init__(
         temperature: float = 0.6,
         top_p: float = 1.0,
         topk: int = 5,
+        timeout: float = 300.0,
         **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, **kwargs)
@@ -42,6 +44,7 @@ def __init__(
         self.temperature = temperature
         self.top_p = top_p
         self.topk = topk
+        self.timeout = timeout
 
     @staticmethod
     def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
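
The new keyword travels with the existing sampling settings, so each wrapper instance can choose its own limit. A minimal construction sketch, assuming the remaining engine options (model path and so on) are forwarded through **kwargs the same way the other settings are; the values below are illustrative, not taken from this commit:

# Hypothetical usage; only temperature, top_p, topk and timeout appear in this diff.
wrapper = VLLMWrapper(
    temperature=0.6,
    top_p=1.0,
    topk=5,
    timeout=120.0,  # give up on any single generation after two minutes
)
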
@@ -57,6 +60,12 @@ def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
         lines.append(prompt)
         return "\n".join(lines)
 
+    async def _consume_generator(self, generator):
+        final_output = None
+        async for request_output in generator:
+            final_output = request_output
+        return final_output
+
     async def generate_answer(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
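
The helper exists because asyncio.wait_for needs an awaitable, while the engine's generate call returns an async generator that has to be iterated to completion; draining it inside a coroutine puts the whole stream under one deadline. A self-contained sketch of the same pattern with a toy stream (slow_stream and drain are illustrative names, not part of this repository):

import asyncio

async def slow_stream():
    # Stand-in for an engine that yields partial outputs over time.
    for i in range(3):
        await asyncio.sleep(0.1)
        yield i

async def drain(gen):
    # Keep only the final item, mirroring _consume_generator above.
    last = None
    async for item in gen:
        last = item
    return last

async def main():
    try:
        # Wrapping the draining coroutine (not the generator itself) gives
        # wait_for a single awaitable to race against the timeout.
        print(await asyncio.wait_for(drain(slow_stream()), timeout=5.0))
    except asyncio.TimeoutError:
        print("stream timed out")

asyncio.run(main())
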
@@ -71,14 +80,27 @@ async def generate_answer(
 
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
 
-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if not final_output or not final_output.outputs:
-            return ""
-
-        return final_output.outputs[0].text
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
+            )
+
+            if not final_output or not final_output.outputs:
+                return ""
+
+            result_text = final_output.outputs[0].text
+            return result_text
+
+        except asyncio.TimeoutError:
+            await self.engine.abort(request_id)
+            raise
+        except asyncio.CancelledError:
+            await self.engine.abort(request_id)
+            raise
+        except Exception as e:
+            await self.engine.abort(request_id)
+            raise
 
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
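
Aborting the request before re-raising asks vLLM to stop scheduling the timed-out generation rather than letting it keep running in the background, while the exception itself still reaches the caller. A hedged caller-side sketch showing how that surfaced asyncio.TimeoutError might be handled (the single-retry policy is illustrative, not part of this commit):

import asyncio

async def answer_with_retry(wrapper, prompt: str, retries: int = 1) -> str:
    # By the time TimeoutError reaches us the wrapper has already aborted the
    # underlying vLLM request, so a retry starts a fresh request.
    for attempt in range(retries + 1):
        try:
            return await wrapper.generate_answer(prompt)
        except asyncio.TimeoutError:
            if attempt == retries:
                raise
    return ""  # unreachable; keeps static checkers satisfied
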
@@ -95,41 +117,54 @@ async def generate_topk_per_token(
 
         result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
 
-        final_output = None
-        async for request_output in result_generator:
-            final_output = request_output
-
-        if (
-            not final_output
-            or not final_output.outputs
-            or not final_output.outputs[0].logprobs
-        ):
-            return []
-
-        top_logprobs = final_output.outputs[0].logprobs[0]
-
-        candidate_tokens = []
-        for _, logprob_obj in top_logprobs.items():
-            tok_str = (
-                logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
-            )
-            prob = float(math.exp(logprob_obj.logprob))
-            candidate_tokens.append(Token(tok_str, prob))
-
-        candidate_tokens.sort(key=lambda x: -x.prob)
-
-        if candidate_tokens:
-            main_token = Token(
-                text=candidate_tokens[0].text,
-                prob=candidate_tokens[0].prob,
-                top_candidates=candidate_tokens,
+        try:
+            final_output = await asyncio.wait_for(
+                self._consume_generator(result_generator),
+                timeout=self.timeout
             )
-            return [main_token]
-        return []
+
+            if (
+                not final_output
+                or not final_output.outputs
+                or not final_output.outputs[0].logprobs
+            ):
+                return []
+
+            top_logprobs = final_output.outputs[0].logprobs[0]
+
+            candidate_tokens = []
+            for _, logprob_obj in top_logprobs.items():
+                tok_str = (
+                    logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+                )
+                prob = float(math.exp(logprob_obj.logprob))
+                candidate_tokens.append(Token(tok_str, prob))
+
+            candidate_tokens.sort(key=lambda x: -x.prob)
+
+            if candidate_tokens:
+                main_token = Token(
+                    text=candidate_tokens[0].text,
+                    prob=candidate_tokens[0].prob,
+                    top_candidates=candidate_tokens,
+                )
+                return [main_token]
+            return []
+
+        except asyncio.TimeoutError:
+            await self.engine.abort(request_id)
+            raise
+        except asyncio.CancelledError:
+            await self.engine.abort(request_id)
+            raise
+        except Exception as e:
+            await self.engine.abort(request_id)
+            raise
 
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
         raise NotImplementedError(
             "VLLMWrapper does not support per-token logprobs yet."
         )
+
