Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions engines/python/setup/djl_python/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ def add_server_maintained_params(request_input: RequestInput,
request_input.server_parameters["output_formatter"] = kwargs.get(
"configs").output_formatter

if input_item.get_property("cancelled"):
request_input.is_cancelled = True

output_formatter = request_input.server_parameters["output_formatter"]
if output_formatter == "json" or output_formatter == "sse":
request_input.tgi_compat = kwargs.get("configs").tgi_compat
Expand Down
11 changes: 11 additions & 0 deletions engines/python/setup/djl_python/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import inspect
import json
from typing import Union, Callable, Any, List, Dict, Optional

from djl_python.output_formatter import get_output_formatter, adapt_legacy_output_formatter
Expand Down Expand Up @@ -108,6 +109,8 @@ def get_next_token(self) -> str:

:return: next_token
"""
if self.is_cancelled():
return ""
if self.next_token_str:
return self.next_token_str
if self.legacy_formatter:
Expand Down Expand Up @@ -181,3 +184,11 @@ def get_client_request_id(self) -> str:
:return: the requestId specified in the HTTP request
"""
return self.request_input.client_request_id

def is_cancelled(self) -> bool:
    """Report whether the client has cancelled this request.

    :return: True when the underlying RequestInput carries the cancelled flag.
    """
    cancelled: bool = self.request_input.is_cancelled
    return cancelled
1 change: 1 addition & 0 deletions engines/python/setup/djl_python/request_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class RequestInput:
parameters: Dict = field(default_factory=lambda: {})
server_parameters: Dict = field(default_factory=lambda: {})
tgi_compat: bool = False
is_cancelled: bool = False


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,17 @@ def postprocess_results(self) -> List[dict]:
req = self.active_requests[i]
res = {
"data": req.get_next_token(),
"last": req.is_last_token(),
"last": req.is_last_token() or req.is_cancelled(),
"content_type": req.get_content_type(),
"request_id": req.get_client_request_id(),
}
if req.get_error_message():
res["error"] = req.get_error_message()
if req.get_error_code():
res["code"] = req.get_error_code()
if req.is_cancelled():
res["error"] = res.get("error", "request has been cancelled")
res["code"] = res.get("code", 499)
req.reset_next_token()
results.append(res)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
from collections import OrderedDict

from vllm import LLMEngine, SamplingParams
from vllm.sampling_params import RequestOutputKind
from vllm.utils import random_uuid, AtomicCounter
from vllm.utils import AtomicCounter

from djl_python.request import Request
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
Expand Down Expand Up @@ -154,6 +155,15 @@ def translate_vllm_params(self, parameters: dict) -> dict:
remove_unused_params=True)
return parameters

def cancel_requests(self):
    """
    Aborts engine processing for every active request the client has
    cancelled, and evicts its entry from the request cache.

    Side effects: calls ``engine.abort_request`` for each cancelled request
    and mutates ``self.request_cache`` in place.
    """
    for req in self.active_requests:
        if not req.is_cancelled():
            continue
        # Hoist the id lookup: the original fetched it three times per request.
        request_id = req.get_client_request_id()
        self.engine.abort_request(request_id)
        # The cache entry may already be absent; pop with a default avoids KeyError.
        self.request_cache.pop(request_id, None)
        # Lazy %-formatting defers string building until the record is emitted.
        logging.info("RequestId[%s] has been cancelled", request_id)

@stop_on_any_exception
def inference(self, new_requests: List[Request]) -> List:
"""
Expand All @@ -164,9 +174,10 @@ def inference(self, new_requests: List[Request]) -> List:
:return results: List of dictionaries, one for each request, that contain output tokens and other data.
"""
self.add_new_requests(new_requests)
self.cancel_requests()
# step 0: register new requests to engine
for request in new_requests:
request_id = random_uuid()
request_id = request.get_client_request_id()
# Chat completions request route
if request.parameters.get("sampling_params") is not None:
prompt_inputs = request.parameters.get("engine_prompt")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ public void run() {
String key = prefix + entry.getKey();
batch.addProperty(key, entry.getValue());
}
if (req.isCancelled()) {
String key = prefix + "cancelled";
batch.addProperty(key, "true");
}

batch.add(prefix + "data", req.getRequest());
String seed = req.getSeed();
Expand Down Expand Up @@ -223,12 +227,16 @@ public void run() {
}

public Output addInput(Input input, int timeout) throws TranslateException {
String requestId = input.getProperty("requestId", "");
String requestIdLogPrefix = "RequestId=[" + requestId + "]";
if (input.isCancelled()) {
logger.warn("{} has been cancelled, not processing request", requestIdLogPrefix);
return new Output(499, "request has been cancelled due to client disconnect");
}
try {
lock.lock();
if (list.size() >= maxRollingBatchSize) {
// Input always reflects a single request here
String requestId = input.getProperty("requestId", "");
String requestIdLogPrefix = "RequestId=[" + requestId + "]";
logger.debug(
"{} exceed max_rolling_batch_size: {}",
requestIdLogPrefix,
Expand Down Expand Up @@ -370,5 +378,9 @@ void addResponse(byte[] json, Map<String, String> properties) {
data.appendContent(nextToken.getBytes(StandardCharsets.UTF_8), last);
}
}

/**
 * Returns whether this job's request has been cancelled.
 *
 * <p>Delegates to the cancelled flag on the wrapped {@code input}.
 *
 * @return {@code true} if the underlying input reports cancellation
 */
boolean isCancelled() {
    return input.isCancelled();
}
}
}
Loading