Skip to content

Commit 325fdec

Browse files
committed
Support session based sticky routing in python engine
1 parent 9e0107b commit 325fdec

File tree

10 files changed

+470
-13
lines changed

10 files changed

+470
-13
lines changed

.github/workflows/integration.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ jobs:
170170
- test: TestCorrectnessTrtLlm
171171
instance: g6
172172
failure-prefix: trtllm
173-
173+
- test: TestStickyRouting
174+
instance: g6
175+
failure-prefix: lmi
174176
outputs:
175177
failure_cpu: ${{ steps.test-failure.outputs.failure_cpu }}
176178
failure_gpu: ${{ steps.test-failure.outputs.failure_gpu }}
@@ -268,4 +270,4 @@ jobs:
268270
./stop_instance.sh $instance_id
269271
270272
instance_id=${{ needs.create-runners.outputs.cpu_instance_id }}
271-
./stop_instance.sh $instance_id
273+
./stop_instance.sh $instance_id

engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
from djl_python.inputs import Input
3131
from djl_python.outputs import Output
3232
from djl_python.encode_decode import decode
33-
from djl_python.async_utils import handle_streaming_response, create_non_stream_output
34-
from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError
33+
from djl_python.async_utils import handle_streaming_response, create_non_stream_output, ProcessedRequest
34+
from djl_python.session_manager import SessionManager
3535

3636
from .request_response_utils import (
3737
ProcessedRequest,
@@ -61,6 +61,7 @@ def __init__(self):
6161
self.vllm_engine_args = None
6262
self.vllm_properties = None
6363
self.model_name = None
64+
self.session_manager = None
6465
self.initialized = False
6566

6667
async def initialize(self, properties: dict):
@@ -119,12 +120,20 @@ async def initialize(self, properties: dict):
119120
tool_parser=self.vllm_properties.tool_call_parser,
120121
reasoning_parser=self.vllm_properties.reasoning_parser,
121122
)
123+
self.session_manager: SessionManager = SessionManager(properties)
122124
self.initialized = True
123125

124126
def preprocess_request(self, inputs: Input) -> ProcessedRequest:
125127
batch = inputs.get_batches()
126128
assert len(batch) == 1, "only one request per batch allowed"
127129
raw_request = batch[0]
130+
131+
# Get session id
132+
session_id = raw_request.get_property("X-Amzn-SageMaker-Session-Id")
133+
session = self.session_manager.get_session(session_id)
134+
if session is None:
135+
raise RuntimeError(f"Requested session {session_id} not found")
136+
128137
content_type = raw_request.get_property("Content-Type")
129138
decoded_payload = decode(raw_request, content_type)
130139

@@ -222,10 +231,53 @@ async def inference(
222231
tokenizer=self.tokenizer,
223232
)
224233

234+
async def create_session(self, inputs: Input):
    """Open a new session and report its id back to the caller.

    Runs the handler health check, asks the session manager for a fresh
    session, and builds an ``Output`` that carries the session id in the
    ``X-Amzn-SageMaker-Session-Id`` property together with a JSON-typed
    confirmation message stored under the ``result`` key.

    Args:
        inputs: The incoming request. It is not read by this method.

    Returns:
        The ``Output`` describing the newly created session.
    """
    await self.check_health()

    new_session = self.session_manager.create_session()
    message = f"Session {new_session.session_id} created"

    response = Output()
    response.add_property("X-Amzn-SageMaker-Session-Id",
                          new_session.session_id)
    response.add_property("Content-Type", "application/json")
    response.add(Output.binary_encode({"result": message}), key="result")

    logger.info(message)
    return response
245+
246+
async def close_session(self, inputs: Input):
    """Close the session named by the request and confirm the closure.

    Runs the handler health check, reads the session id from the
    ``X-Amzn-SageMaker-Session-Id`` request property, and hands it to the
    session manager to close. The returned ``Output`` sets
    ``X-Amzn-SageMaker-Session-Closed`` to ``"true"`` and carries a
    JSON-typed confirmation message under the ``result`` key.

    Args:
        inputs: The incoming request holding the session id property.

    Returns:
        The ``Output`` confirming that the session was closed.
    """
    await self.check_health()

    closing_id = inputs.get_property("X-Amzn-SageMaker-Session-Id")
    self.session_manager.close_session(closing_id)
    message = f"Session {closing_id} closed"

    response = Output()
    response.add_property("X-Amzn-SageMaker-Session-Closed", "true")
    response.add_property("Content-Type", "application/json")
    response.add(Output.binary_encode({"result": message}), key="result")

    logger.info(message)
    return response
258+
225259

226260
service = VLLMHandler()
227261

228262

263+
async def create_session(inputs: Input) -> Output:
    """Module-level entry point for the ``create_session`` handler.

    Initializes the shared ``service`` from the request properties on first
    use, then delegates to ``service.create_session``.

    Args:
        inputs: The incoming request; its properties configure the service
            when it has not been initialized yet.

    Returns:
        The ``Output`` produced by ``service.create_session``.
    """
    needs_init = not service.initialized
    if needs_init:
        await service.initialize(inputs.get_properties())
        logger.info("vllm service initialized")

    return await service.create_session(inputs)
270+
271+
272+
async def close_session(inputs: Input) -> Output:
    """Module-level entry point for the ``close_session`` handler.

    Initializes the shared ``service`` from the request properties on first
    use, then delegates to ``service.close_session`` so that the session
    named in the request is actually closed.

    Args:
        inputs: The incoming request; it carries the session id to close and,
            on first use, the properties used to initialize the service.

    Returns:
        The ``Output`` produced by ``service.close_session``.
    """
    if not service.initialized:
        await service.initialize(inputs.get_properties())
        logger.info("vllm service initialized")

    # Must dispatch to close_session: calling create_session here would open
    # a brand-new session and leave the requested one open.
    outputs = await service.close_session(inputs)
    return outputs
279+
280+
229281
async def handle(
230282
inputs: Input
231283
) -> Optional[Union[Output, AsyncGenerator[Output, None]]]:

engines/python/src/main/java/ai/djl/python/engine/PyProcess.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ Output predict(Input inputs, int timeout, boolean initialLoad) throws TranslateE
158158
// In RollingBatch, we queue adapter loading jobs to occur after the initial load.
159159
// Executing those in RollingBatch context doesn't work, so we need to handle them in the
160160
// 'standard' way.
161-
if (initialLoad || inputs.getProperty("handler", null) != null) {
161+
if (initialLoad
162+
|| (inputs.getProperty("handler", null) != null && asyncRequestManager == null)) {
162163
return predictStandard(inputs, timeout, initialLoad);
163164
}
164165
if (rollingBatch != null) {

serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java

Lines changed: 148 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,20 @@
2121
import ai.djl.serving.http.list.ListWorkflowsResponse;
2222
import ai.djl.serving.models.Endpoint;
2323
import ai.djl.serving.models.ModelManager;
24+
import ai.djl.serving.sessions.SessionManager;
25+
import ai.djl.serving.util.ConfigManager;
2426
import ai.djl.serving.util.NettyUtils;
2527
import ai.djl.serving.wlm.ModelInfo;
28+
import ai.djl.serving.wlm.WorkLoadManager;
29+
import ai.djl.serving.wlm.WorkerPool;
2630
import ai.djl.serving.wlm.WorkerPoolConfig;
31+
import ai.djl.serving.wlm.util.WlmCapacityException;
32+
import ai.djl.serving.wlm.util.WlmException;
2733
import ai.djl.serving.workflow.BadWorkflowException;
2834
import ai.djl.serving.workflow.Workflow;
2935
import ai.djl.serving.workflow.WorkflowDefinition;
3036
import ai.djl.serving.workflow.WorkflowTemplates;
37+
import ai.djl.translate.TranslateException;
3138
import ai.djl.util.JsonUtils;
3239
import ai.djl.util.Pair;
3340

@@ -40,32 +47,41 @@
4047
import io.netty.handler.codec.http.QueryStringDecoder;
4148
import io.netty.util.CharsetUtil;
4249

50+
import org.slf4j.Logger;
51+
import org.slf4j.LoggerFactory;
52+
4353
import java.io.IOException;
4454
import java.lang.reflect.Method;
4555
import java.net.URI;
4656
import java.util.ArrayList;
4757
import java.util.Collections;
4858
import java.util.List;
4959
import java.util.Map;
60+
import java.util.NoSuchElementException;
5061
import java.util.concurrent.CompletableFuture;
5162
import java.util.regex.Pattern;
5263
import java.util.stream.Collectors;
5364

5465
/** A class handling inbound HTTP requests to the management API. */
5566
public class ManagementRequestHandler extends HttpRequestHandler {
5667

68+
private static final Logger logger = LoggerFactory.getLogger(ManagementRequestHandler.class);
69+
5770
private static final Pattern WORKFLOWS_PATTERN = Pattern.compile("^/workflows([/?].*)?");
5871
private static final Pattern MODELS_PATTERN = Pattern.compile("^/models([/?].*)?");
5972
private static final Pattern INVOKE_PATTERN = Pattern.compile("^/models/.+/invoke$");
6073
private static final Pattern SERVER_PATTERN = Pattern.compile("^/server/.+");
74+
private static final Pattern SESSION_PATTERN = Pattern.compile("^/(create|close)_session");
6175

6276
/** {@inheritDoc} */
6377
@Override
6478
public boolean acceptInboundMessage(Object msg) throws Exception {
6579
if (super.acceptInboundMessage(msg)) {
6680
FullHttpRequest req = (FullHttpRequest) msg;
6781
String uri = req.uri();
68-
if (WORKFLOWS_PATTERN.matcher(uri).matches() || SERVER_PATTERN.matcher(uri).matches()) {
82+
if (WORKFLOWS_PATTERN.matcher(uri).matches()
83+
|| SERVER_PATTERN.matcher(uri).matches()
84+
|| SESSION_PATTERN.matcher(uri).matches()) {
6985
return true;
7086
} else if (AdapterManagementRequestHandler.ADAPTERS_PATTERN.matcher(uri).matches()) {
7187
return false;
@@ -107,7 +123,11 @@ protected void handleRequest(
107123
}
108124
return;
109125
} else if (HttpMethod.POST.equals(method)) {
110-
if ("models".equals(segments[1])) {
126+
if ("create_session".equals(segments[1])) {
127+
handleCreateSession(ctx);
128+
} else if ("close_session".equals(segments[1])) {
129+
handleCloseSession(ctx, req);
130+
} else if ("models".equals(segments[1])) {
111131
handleRegisterModel(ctx, req, decoder);
112132
} else {
113133
handleRegisterWorkflow(ctx, decoder);
@@ -384,6 +404,95 @@ private void handleScaleWorkflow(
384404
}
385405
}
386406

407+
private void handleCreateSession(final ChannelHandlerContext ctx) {
408+
WorkLoadManager wlm = ModelManager.getInstance().getWorkLoadManager();
409+
String modelName =
410+
ModelManager.getInstance()
411+
.getSingleStartupWorkflow()
412+
.orElseThrow(
413+
() ->
414+
new BadRequestException(
415+
"there should be only a single startup"
416+
+ " model used."));
417+
WorkerPool<Input, Output> wp = wlm.getWorkerPoolById(modelName);
418+
if (wp == null) {
419+
throw new BadRequestException(
420+
HttpResponseStatus.NOT_FOUND.code(),
421+
"The model " + modelName + " was not found");
422+
}
423+
ModelInfo<Input, Output> modelInfo = getModelInfo(wp);
424+
425+
SessionManager<Input, Output> sessionManager = SessionManager.newInstance(modelInfo);
426+
sessionManager
427+
.createSession(wlm)
428+
.whenCompleteAsync(
429+
(o, t) -> {
430+
if (o != null) {
431+
if (o.getCode() >= 300) {
432+
throw new BadRequestException(o.getCode(), o.getMessage());
433+
}
434+
NettyUtils.sendJsonResponse(
435+
ctx,
436+
new StatusResponse(o.getMessage()),
437+
HttpResponseStatus.valueOf(o.getCode()));
438+
}
439+
})
440+
.exceptionally(
441+
t -> {
442+
onException(t.getCause(), ctx);
443+
return null;
444+
});
445+
}
446+
447+
private void handleCloseSession(final ChannelHandlerContext ctx, FullHttpRequest req) {
448+
WorkLoadManager wlm = ModelManager.getInstance().getWorkLoadManager();
449+
String modelName =
450+
ModelManager.getInstance()
451+
.getSingleStartupWorkflow()
452+
.orElseThrow(
453+
() ->
454+
new BadRequestException(
455+
"there should be only a single startup"
456+
+ " model used."));
457+
WorkerPool<Input, Output> wp = wlm.getWorkerPoolById(modelName);
458+
if (wp == null) {
459+
throw new BadRequestException(
460+
HttpResponseStatus.NOT_FOUND.code(),
461+
"The model " + modelName + " was not found");
462+
}
463+
ModelInfo<Input, Output> modelInfo = getModelInfo(wp);
464+
String sessionId = req.headers().get("X-Amzn-SageMaker-Session-Id");
465+
466+
SessionManager<Input, Output> sessionManager = SessionManager.newInstance(modelInfo);
467+
sessionManager
468+
.closeSession(wlm, sessionId)
469+
.whenCompleteAsync(
470+
(o, t) -> {
471+
if (o != null) {
472+
if (o.getCode() >= 300) {
473+
throw new BadRequestException(o.getCode(), o.getMessage());
474+
}
475+
NettyUtils.sendJsonResponse(
476+
ctx,
477+
new StatusResponse(o.getMessage()),
478+
HttpResponseStatus.valueOf(o.getCode()));
479+
}
480+
})
481+
.exceptionally(
482+
t -> {
483+
onException(t.getCause(), ctx);
484+
return null;
485+
});
486+
}
487+
488+
private ModelInfo<Input, Output> getModelInfo(WorkerPool<Input, Output> wp) {
489+
if (!(wp.getWpc() instanceof ModelInfo)) {
490+
String modelName = wp.getWpc().getId();
491+
throw new BadRequestException("The worker " + modelName + " is not a model");
492+
}
493+
return (ModelInfo<Input, Output>) wp.getWpc();
494+
}
495+
387496
@SuppressWarnings("unchecked")
388497
private void handleConfigLogs(ChannelHandlerContext ctx, QueryStringDecoder decoder) {
389498
String logLevel = NettyUtils.getParameter(decoder, "level", null);
@@ -408,4 +517,41 @@ private void handleConfigLogs(ChannelHandlerContext ctx, QueryStringDecoder deco
408517
StatusResponse resp = new StatusResponse("OK");
409518
NettyUtils.sendJsonResponse(ctx, resp);
410519
}
520+
521+
private void onException(Throwable t, ChannelHandlerContext ctx) {
522+
ConfigManager config = ConfigManager.getInstance();
523+
int code;
524+
String requestIdLogPrefix = "";
525+
if (ctx != null) {
526+
String requestId = NettyUtils.getRequestId(ctx.channel());
527+
requestIdLogPrefix = "RequestId=[" + requestId + "]: ";
528+
}
529+
if (t instanceof TranslateException) {
530+
logger.debug("{}{}", requestIdLogPrefix, t.getMessage(), t);
531+
code = config.getBadRequestErrorHttpCode();
532+
} else if (t instanceof BadRequestException) {
533+
code = ((BadRequestException) t).getCode();
534+
} else if (t instanceof WlmException) {
535+
logger.warn("{}{}", requestIdLogPrefix, t.getMessage(), t);
536+
if (t instanceof WlmCapacityException) {
537+
code = config.getThrottleErrorHttpCode();
538+
} else {
539+
code = config.getWlmErrorHttpCode();
540+
}
541+
} else if (t instanceof NoSuchElementException) {
542+
logger.warn(requestIdLogPrefix, t);
543+
code = HttpResponseStatus.NOT_FOUND.code();
544+
} else if (t instanceof IllegalArgumentException) {
545+
logger.warn(requestIdLogPrefix, t);
546+
code = HttpResponseStatus.CONFLICT.code();
547+
} else {
548+
logger.warn("{} Unexpected error", requestIdLogPrefix, t);
549+
code = config.getServerErrorHttpCode();
550+
}
551+
HttpResponseStatus status = HttpResponseStatus.valueOf(code);
552+
553+
if (ctx != null) {
554+
NettyUtils.sendError(ctx, status, t);
555+
}
556+
}
411557
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.serving.sessions;
14+
15+
import ai.djl.modality.Input;
16+
import ai.djl.modality.Output;
17+
import ai.djl.serving.wlm.ModelInfo;
18+
19+
/** An overload of {@link SessionManager} for the python engine. */
20+
public class PySessionManager extends SessionManager<Input, Output> {
21+
22+
protected PySessionManager(ModelInfo<Input, Output> modelInfo) {
23+
super(modelInfo);
24+
}
25+
26+
@Override
27+
protected Input getCreateSessionInput() {
28+
Input input = new Input();
29+
input.addProperty("handler", "create_session");
30+
return input;
31+
}
32+
33+
@Override
34+
protected Input getCloseSessionInput(String sessionId) {
35+
Input input = new Input();
36+
input.addProperty("handler", "close_session");
37+
input.addProperty("X-Amzn-SageMaker-Session-Id", sessionId);
38+
return input;
39+
}
40+
}

0 commit comments

Comments
 (0)