2121import ai .djl .serving .http .list .ListWorkflowsResponse ;
2222import ai .djl .serving .models .Endpoint ;
2323import ai .djl .serving .models .ModelManager ;
24+ import ai .djl .serving .sessions .SessionManager ;
25+ import ai .djl .serving .util .ConfigManager ;
2426import ai .djl .serving .util .NettyUtils ;
2527import ai .djl .serving .wlm .ModelInfo ;
28+ import ai .djl .serving .wlm .WorkLoadManager ;
29+ import ai .djl .serving .wlm .WorkerPool ;
2630import ai .djl .serving .wlm .WorkerPoolConfig ;
31+ import ai .djl .serving .wlm .util .WlmCapacityException ;
32+ import ai .djl .serving .wlm .util .WlmException ;
2733import ai .djl .serving .workflow .BadWorkflowException ;
2834import ai .djl .serving .workflow .Workflow ;
2935import ai .djl .serving .workflow .WorkflowDefinition ;
3036import ai .djl .serving .workflow .WorkflowTemplates ;
37+ import ai .djl .translate .TranslateException ;
3138import ai .djl .util .JsonUtils ;
3239import ai .djl .util .Pair ;
3340
4047import io .netty .handler .codec .http .QueryStringDecoder ;
4148import io .netty .util .CharsetUtil ;
4249
50+ import org .slf4j .Logger ;
51+ import org .slf4j .LoggerFactory ;
52+
4353import java .io .IOException ;
4454import java .lang .reflect .Method ;
4555import java .net .URI ;
4656import java .util .ArrayList ;
4757import java .util .Collections ;
4858import java .util .List ;
4959import java .util .Map ;
60+ import java .util .NoSuchElementException ;
5061import java .util .concurrent .CompletableFuture ;
5162import java .util .regex .Pattern ;
5263import java .util .stream .Collectors ;
5364
5465/** A class handling inbound HTTP requests to the management API. */
5566public class ManagementRequestHandler extends HttpRequestHandler {
5667
68+ private static final Logger logger = LoggerFactory .getLogger (ManagementRequestHandler .class );
69+
5770 private static final Pattern WORKFLOWS_PATTERN = Pattern .compile ("^/workflows([/?].*)?" );
5871 private static final Pattern MODELS_PATTERN = Pattern .compile ("^/models([/?].*)?" );
5972 private static final Pattern INVOKE_PATTERN = Pattern .compile ("^/models/.+/invoke$" );
6073 private static final Pattern SERVER_PATTERN = Pattern .compile ("^/server/.+" );
74+ private static final Pattern SESSION_PATTERN = Pattern .compile ("^/(create|close)_session" );
6175
6276 /** {@inheritDoc} */
6377 @ Override
6478 public boolean acceptInboundMessage (Object msg ) throws Exception {
6579 if (super .acceptInboundMessage (msg )) {
6680 FullHttpRequest req = (FullHttpRequest ) msg ;
6781 String uri = req .uri ();
68- if (WORKFLOWS_PATTERN .matcher (uri ).matches () || SERVER_PATTERN .matcher (uri ).matches ()) {
82+ if (WORKFLOWS_PATTERN .matcher (uri ).matches ()
83+ || SERVER_PATTERN .matcher (uri ).matches ()
84+ || SESSION_PATTERN .matcher (uri ).matches ()) {
6985 return true ;
7086 } else if (AdapterManagementRequestHandler .ADAPTERS_PATTERN .matcher (uri ).matches ()) {
7187 return false ;
@@ -107,7 +123,11 @@ protected void handleRequest(
107123 }
108124 return ;
109125 } else if (HttpMethod .POST .equals (method )) {
110- if ("models" .equals (segments [1 ])) {
126+ if ("create_session" .equals (segments [1 ])) {
127+ handleCreateSession (ctx );
128+ } else if ("close_session" .equals (segments [1 ])) {
129+ handleCloseSession (ctx , req );
130+ } else if ("models" .equals (segments [1 ])) {
111131 handleRegisterModel (ctx , req , decoder );
112132 } else {
113133 handleRegisterWorkflow (ctx , decoder );
@@ -384,6 +404,95 @@ private void handleScaleWorkflow(
384404 }
385405 }
386406
407+ private void handleCreateSession (final ChannelHandlerContext ctx ) {
408+ WorkLoadManager wlm = ModelManager .getInstance ().getWorkLoadManager ();
409+ String modelName =
410+ ModelManager .getInstance ()
411+ .getSingleStartupWorkflow ()
412+ .orElseThrow (
413+ () ->
414+ new BadRequestException (
415+ "there should be only a single startup"
416+ + " model used." ));
417+ WorkerPool <Input , Output > wp = wlm .getWorkerPoolById (modelName );
418+ if (wp == null ) {
419+ throw new BadRequestException (
420+ HttpResponseStatus .NOT_FOUND .code (),
421+ "The model " + modelName + " was not found" );
422+ }
423+ ModelInfo <Input , Output > modelInfo = getModelInfo (wp );
424+
425+ SessionManager <Input , Output > sessionManager = SessionManager .newInstance (modelInfo );
426+ sessionManager
427+ .createSession (wlm )
428+ .whenCompleteAsync (
429+ (o , t ) -> {
430+ if (o != null ) {
431+ if (o .getCode () >= 300 ) {
432+ throw new BadRequestException (o .getCode (), o .getMessage ());
433+ }
434+ NettyUtils .sendJsonResponse (
435+ ctx ,
436+ new StatusResponse (o .getMessage ()),
437+ HttpResponseStatus .valueOf (o .getCode ()));
438+ }
439+ })
440+ .exceptionally (
441+ t -> {
442+ onException (t .getCause (), ctx );
443+ return null ;
444+ });
445+ }
446+
447+ private void handleCloseSession (final ChannelHandlerContext ctx , FullHttpRequest req ) {
448+ WorkLoadManager wlm = ModelManager .getInstance ().getWorkLoadManager ();
449+ String modelName =
450+ ModelManager .getInstance ()
451+ .getSingleStartupWorkflow ()
452+ .orElseThrow (
453+ () ->
454+ new BadRequestException (
455+ "there should be only a single startup"
456+ + " model used." ));
457+ WorkerPool <Input , Output > wp = wlm .getWorkerPoolById (modelName );
458+ if (wp == null ) {
459+ throw new BadRequestException (
460+ HttpResponseStatus .NOT_FOUND .code (),
461+ "The model " + modelName + " was not found" );
462+ }
463+ ModelInfo <Input , Output > modelInfo = getModelInfo (wp );
464+ String sessionId = req .headers ().get ("X-Amzn-SageMaker-Session-Id" );
465+
466+ SessionManager <Input , Output > sessionManager = SessionManager .newInstance (modelInfo );
467+ sessionManager
468+ .closeSession (wlm , sessionId )
469+ .whenCompleteAsync (
470+ (o , t ) -> {
471+ if (o != null ) {
472+ if (o .getCode () >= 300 ) {
473+ throw new BadRequestException (o .getCode (), o .getMessage ());
474+ }
475+ NettyUtils .sendJsonResponse (
476+ ctx ,
477+ new StatusResponse (o .getMessage ()),
478+ HttpResponseStatus .valueOf (o .getCode ()));
479+ }
480+ })
481+ .exceptionally (
482+ t -> {
483+ onException (t .getCause (), ctx );
484+ return null ;
485+ });
486+ }
487+
488+ private ModelInfo <Input , Output > getModelInfo (WorkerPool <Input , Output > wp ) {
489+ if (!(wp .getWpc () instanceof ModelInfo )) {
490+ String modelName = wp .getWpc ().getId ();
491+ throw new BadRequestException ("The worker " + modelName + " is not a model" );
492+ }
493+ return (ModelInfo <Input , Output >) wp .getWpc ();
494+ }
495+
387496 @ SuppressWarnings ("unchecked" )
388497 private void handleConfigLogs (ChannelHandlerContext ctx , QueryStringDecoder decoder ) {
389498 String logLevel = NettyUtils .getParameter (decoder , "level" , null );
@@ -408,4 +517,41 @@ private void handleConfigLogs(ChannelHandlerContext ctx, QueryStringDecoder deco
408517 StatusResponse resp = new StatusResponse ("OK" );
409518 NettyUtils .sendJsonResponse (ctx , resp );
410519 }
520+
521+ private void onException (Throwable t , ChannelHandlerContext ctx ) {
522+ ConfigManager config = ConfigManager .getInstance ();
523+ int code ;
524+ String requestIdLogPrefix = "" ;
525+ if (ctx != null ) {
526+ String requestId = NettyUtils .getRequestId (ctx .channel ());
527+ requestIdLogPrefix = "RequestId=[" + requestId + "]: " ;
528+ }
529+ if (t instanceof TranslateException ) {
530+ logger .debug ("{}{}" , requestIdLogPrefix , t .getMessage (), t );
531+ code = config .getBadRequestErrorHttpCode ();
532+ } else if (t instanceof BadRequestException ) {
533+ code = ((BadRequestException ) t ).getCode ();
534+ } else if (t instanceof WlmException ) {
535+ logger .warn ("{}{}" , requestIdLogPrefix , t .getMessage (), t );
536+ if (t instanceof WlmCapacityException ) {
537+ code = config .getThrottleErrorHttpCode ();
538+ } else {
539+ code = config .getWlmErrorHttpCode ();
540+ }
541+ } else if (t instanceof NoSuchElementException ) {
542+ logger .warn (requestIdLogPrefix , t );
543+ code = HttpResponseStatus .NOT_FOUND .code ();
544+ } else if (t instanceof IllegalArgumentException ) {
545+ logger .warn (requestIdLogPrefix , t );
546+ code = HttpResponseStatus .CONFLICT .code ();
547+ } else {
548+ logger .warn ("{} Unexpected error" , requestIdLogPrefix , t );
549+ code = config .getServerErrorHttpCode ();
550+ }
551+ HttpResponseStatus status = HttpResponseStatus .valueOf (code );
552+
553+ if (ctx != null ) {
554+ NettyUtils .sendError (ctx , status , t );
555+ }
556+ }
411557}
0 commit comments