
Commit 492520d

Merge pull request #5588 from hpcaitech/feat/online-serving

[Feature] Online Serving

2 parents d482922 + 5d9a494

21 files changed: +1172 −34 lines

colossalai/inference/batch_bucket.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -62,6 +62,9 @@ def is_empty(self):
     def current_batch_size(self):
         return self._current_batch_size
 
+    def __len__(self):
+        return self._current_batch_size
+
     @property
     def available_batch_size(self):
         return self.max_batch_size - self._current_batch_size
```
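
Adding `__len__` lets callers treat a `BatchBucket` as a sized container, so `len(bucket)` can replace direct reads of `current_batch_size`. A minimal sketch (the helper below is hypothetical; `max_batch_size` is the same attribute already used by `available_batch_size` above):

```python
def has_capacity(bucket) -> bool:
    # len(bucket) now dispatches to BatchBucket.__len__, i.e. the current
    # batch size, so this mirrors available_batch_size > 0.
    return len(bucket) < bucket.max_batch_size
```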

colossalai/inference/config.py

Lines changed: 17 additions & 3 deletions

```diff
@@ -1,10 +1,9 @@
 """
 Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """
-
 import logging
-from dataclasses import dataclass
-from typing import Optional, Union
+from dataclasses import dataclass, fields
+from typing import Any, Dict, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -214,3 +213,18 @@ def to_generation_config(self, model_config) -> GenerationConfig:
                 meta_config[type] = getattr(model_config, type)
 
         return GenerationConfig.from_dict(meta_config)
+
+    @classmethod
+    def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in fields(cls)]
+        inference_config_args = {}
+        for attr in attrs:
+            if attr in config_dict:
+                inference_config_args[attr] = config_dict[attr]
+            else:
+                inference_config_args[attr] = getattr(cls, attr)
+
+        # Set the attributes from the parsed arguments.
+        inference_config = cls(**inference_config_args)
+        return inference_config
```

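`from_dict` copies only the keys that match `InferenceConfig`'s dataclass fields and falls back to the class-level default for anything missing, so a server can build a config straight from a parsed JSON body. A minimal usage sketch, assuming `max_batch_size` is one of those fields (any unknown key is simply ignored):

```python
from colossalai.inference.config import InferenceConfig

# Parsed from an HTTP request body or CLI args. The unrecognized
# "model_name" key is ignored because from_dict only iterates over
# InferenceConfig's declared dataclass fields.
payload = {"max_batch_size": 32, "model_name": "llama-7b"}

config = InferenceConfig.from_dict(payload)
assert config.max_batch_size == 32
```

One caveat: the `getattr(cls, attr)` fallback only resolves fields declared with a plain default value; a field using `default_factory` is not a class attribute and would raise `AttributeError`.
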
Lines changed: 309 additions & 0 deletions (new file)

```python
import asyncio
import logging
from functools import partial
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type

from colossalai.inference.core.engine import InferenceEngine

# CLI logger
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("colossalai-inference")


def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
    msg = "Task finished unexpectedly. This should never happen! "
    try:
        try:
            task.result()
        except asyncio.CancelledError:
            return
        except Exception as exc:
            raise RuntimeError(msg + " See stack trace above for the actual cause.") from exc
        raise RuntimeError(msg)
    except Exception as exc:
        request_tracker.propagate_exception(exc)
        raise exc


class RequstStream:
    """
    A stream of outputs for a request that can be awaited asynchronously.

    Attributes:
        request_id: The id of the request.
        _future: A future that will be set when the request is finished.

    Methods: set_result and get_result; the result is set exactly once, when the
    request finishes, and `self._future` is then marked done.
    """

    def __init__(self, request_id: int) -> None:
        self.request_id = request_id
        self._future = asyncio.Future()

    def set_result(self, result) -> None:
        """Set the final result and signal that it is ready."""
        if not self._future.done():
            self._future.set_result(result)

    async def get_result(self):
        """Wait for the result to be set and return it."""
        return await self._future

    @property
    def finished(self) -> bool:
        """Check if the stream has finished by checking if the future is done."""
        return self._future.done()


class Tracer:
    """
    Records new requests and finished requests.

    Attributes:
        _request_streams: One stream per request, used to trace the output.
        _finished_requests: A queue that stores finished requests.
        _new_requests: New requests are stored in this queue first, before being sent to the engine.
        new_requests_event: An event that notifies the engine that there are new requests.
    """

    def __init__(self) -> None:
        self._request_streams: Dict[int, RequstStream] = {}
        self._finished_requests: asyncio.Queue[int] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[RequstStream, dict]] = asyncio.Queue()
        self.new_requests_event = None

    def __contains__(self, item):
        return item in self._request_streams

    def init_event(self):
        self.new_requests_event = asyncio.Event()

    def propagate_exception(self, exc: Exception, request_id: Optional[int] = None) -> None:
        """
        Propagate an exception to request streams (to all of them if request_id is None).
        """
        if request_id is not None:
            self._request_streams[request_id].set_result(exc)
        else:
            for stream in self._request_streams.values():
                stream.set_result(exc)

    def process_finished_request(self, finished_request) -> None:
        """Process a finished request from the engine."""
        request_id = finished_request.request_id
        try:
            self._request_streams[request_id].set_result(finished_request)
        except KeyError as exc:
            raise RuntimeError(f"The request_id {request_id} is not found in our stream, please check") from exc
        self.abort_request(request_id)

    def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream:
        """
        Add a request to be sent to the engine on the next background loop iteration.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = RequstStream(request_id)
        logger.info(f"Added request {request_id}.")
        self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
        """Abort a request during the next background loop iteration."""
        if verbose:
            logger.info(f"Aborted request {request_id}.")

        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[request_id].finished:
            # The request has already finished or been aborted.
            # Requests still in _new_requests are aborted when they are
            # fetched (if marked aborted).
            return

        self._request_streams[request_id].set_result(None)

    def get_new_requests(self):
        """
        Get new requests from the http server.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[int] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if new_request["request_id"] in finished_requests:
                # The request has been aborted.
                stream.set_result(None)
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        self.new_requests_event.clear()

        return new_requests

    async def wait_for_new_requests(self):
        await self.new_requests_event.wait()


class _AsyncInferenceEngine(InferenceEngine):
    """
    Async methods for the Inference Engine. This engine extends InferenceEngine;
    the additional methods are only used by the AsyncInferenceEngine wrapper below.

    Methods: 1. async_step: The async version of Engine.step()
    """

    async def async_step(self) -> List[str]:
        """
        The async version of Engine.step().
        Performs one decoding iteration and returns newly generated results.

        It first schedules the sequences to be executed in the next iteration.
        Then, it executes the model and updates the scheduler with the model
        outputs. Finally, it decodes the sequences and returns the newly
        generated results.
        """
        batch = self.request_handler.schedule()
        loop = asyncio.get_running_loop()

        # Use run_in_executor to run the synchronous model.forward() without
        # blocking the event loop.
        logits = await loop.run_in_executor(
            None,
            self.model,
            batch,
            self.k_cache,
            self.v_cache,
        )

        if self.inference_config.pad_input:
            logits = logits[:, -1, :]
        self.request_handler.search_tokens(self.generation_config, logits)

        finished_sequences = self.request_handler.update()
        for sequence in finished_sequences:
            sequence.output = self.tokenizer.decode(sequence.output_token_id)

        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0


class AsyncInferenceEngine:
    """An asynchronous wrapper for the InferenceEngine class.

    This class wraps InferenceEngine to make it asynchronous. It uses asyncio
    to create a background loop that keeps processing incoming requests. Note
    that this class does not hold the model directly; when a new request comes
    in, it first calls `add_request`, and the Tracer records the request and
    hands it to the background `InferenceEngine` (driven by the background
    loop) for processing. You can consider this engine as an interface.
    """

    _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine

    def __init__(self, start_engine_loop: bool = True, **kwargs):
        self.engine = self._init_engine(**kwargs)
        self.background_loop = None
        # Reference to the unshielded loop task.
        self._background_loop_unshielded = None
        self.start_engine_loop = start_engine_loop
        self._request_tracer = Tracer()

    @property
    def background_loop_status(self):
        return self.background_loop is not None and not self.background_loop.done()

    def start_background_loop(self):
        if self.background_loop_status:
            raise RuntimeError("An existing loop is already running")

        self._request_tracer.init_event()

        self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
        self._background_loop_unshielded.add_done_callback(
            partial(_raise_exception_on_finish, request_tracker=self._request_tracer)
        )
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def _init_engine(self, **kwargs):
        return self._engine_class(**kwargs)

    async def step(self):
        """
        Run the engine to process requests.

        Returns True if there are in-progress requests.
        """
        new_requests = self._request_tracer.get_new_requests()
        for new_request in new_requests:
            self.engine.add_single_request(**new_request)
        newly_finished_seqs, has_running_requests = await self.engine.async_step()

        for seq in newly_finished_seqs:
            self._request_tracer.process_finished_request(seq)

        return has_running_requests

    async def _engine_abort(self, request_ids: Iterable[int]):
        self.engine.abort_request(request_ids)

    async def abort(self, request_id: int):
        """
        Abort a single request.
        """
        if not self.background_loop_status:
            raise RuntimeError("Background loop is not running or was not launched correctly.")
        return self._abort(request_id)

    def _abort(self, request_id: int):
        self._request_tracer.abort_request(request_id)

    async def run_engine_loop(self):
        processing_requests = False
        while True:
            if not processing_requests:
                await self._request_tracer.wait_for_new_requests()
            processing_requests = await self.step()
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: int,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
    ) -> RequstStream:
        """
        Add a request to the background tracker (waiting queue) and start the
        background loop if needed.
        """
        if not self.background_loop_status:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise RuntimeError("Background loop is not running.")
        stream = self._request_tracer.add_request(
            request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
        )
        return stream

    async def generate(
        self,
        request_id: int,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
    ) -> AsyncIterator[str]:
        """
        Generate output for a request. It receives the request from the http
        server, adds it to the waiting queue of the async engine, and returns
        the output sequence.
        """
        try:
            stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
            return await stream.get_result()

        except (Exception, asyncio.CancelledError) as e:
            # If there is an exception or the coroutine is cancelled, abort the request.
            self._abort(request_id)
            raise e
```
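
Taken together: `generate` registers a request with the Tracer, the background loop picks it up on the next iteration and steps the engine until the sequence finishes, and the awaited `RequstStream` future then resolves with the finished sequence. Note that despite the `AsyncIterator[str]` annotation, `generate` here resolves to a single finished result rather than streaming tokens. A minimal driver sketch under stated assumptions: the module path in the import and the `model`/`tokenizer` keyword names are illustrative only, since the constructor forwards all extra kwargs unchecked to `InferenceEngine`, whose exact signature is not shown in this diff.

```python
import asyncio

from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
# Assumed module path for the new file added in this PR:
from colossalai.inference.core.async_engine import AsyncInferenceEngine


async def main():
    # Kwargs other than start_engine_loop are forwarded to InferenceEngine;
    # the names used here are assumptions, not the verified signature.
    engine = AsyncInferenceEngine(
        start_engine_loop=True,
        model=AutoModelForCausalLM.from_pretrained("llama-7b"),  # placeholder model id
        tokenizer=AutoTokenizer.from_pretrained("llama-7b"),
        inference_config=InferenceConfig.from_dict({"max_batch_size": 16}),
    )

    # generate() awaits the RequstStream future; the background loop is
    # started lazily on the first request.
    finished = await engine.generate(request_id=0, prompt="What is Colossal-AI?")
    # async_step() decodes finished sequences into `output` before returning them.
    print(finished.output)


asyncio.run(main())
```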
