@@ -441,6 +441,9 @@ def run(
441441
442442 """
443443
    def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
        """Hook invoked before the worker starts its loop.

        The base implementation is a deliberate no-op; subclasses override this
        to validate or adjust ``lit_api`` (e.g. generator checks for streaming)
        before any requests are processed.

        Args:
            lit_api: The user-provided API instance the loop will drive.
            spec: Optional spec attached to the server, or ``None``.
        """
        pass
446+
444447 def __call__ (
445448 self ,
446449 lit_api : LitAPI ,
@@ -487,7 +490,109 @@ def run(
487490 raise NotImplementedError
488491
489492
490- class SingleLoop (_BaseLoop ):
493+ class LitLoop (_BaseLoop ):
494+ def __init__ (self ):
495+ self ._context = {}
496+
497+ def get_batch_requests (self , lit_api : LitAPI , request_queue : Queue , max_batch_size : int , batch_timeout : float ):
498+ if max_batch_size <= 1 :
499+ raise ValueError ("max_batch_size must be greater than 1" )
500+
501+ batches , timed_out_uids = collate_requests (
502+ lit_api ,
503+ request_queue ,
504+ max_batch_size ,
505+ batch_timeout ,
506+ )
507+ return batches , timed_out_uids
508+
509+ def get_request (self , request_queue : Queue , timeout : float = 1.0 ):
510+ response_queue_id , uid , timestamp , x_enc = request_queue .get (timeout = timeout )
511+ return response_queue_id , uid , timestamp , x_enc
512+
513+ def populate_context (self , lit_spec : LitSpec , request : Any ):
514+ if lit_spec and hasattr (lit_spec , "populate_context" ):
515+ lit_spec .populate_context (self ._context , request )
516+
517+ def put_response (
518+ self , response_queues : List [Queue ], response_queue_id : int , uid : str , response_data : Any , status : LitAPIStatus
519+ ) -> None :
520+ response_queues [response_queue_id ].put ((uid , (response_data , status )))
521+
522+ def put_error_response (
523+ self , response_queues : List [Queue ], response_queue_id : int , uid : str , error : Exception
524+ ) -> None :
525+ response_queues [response_queue_id ].put ((uid , (error , LitAPIStatus .ERROR )))
526+
527+
class DefaultLoop(LitLoop):
    """Loop used when the server picks a built-in loop; validates the LitAPI contract."""

    def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
        """Check that a streaming ``lit_api`` implements its methods as generators.

        When a spec is attached, validation is skipped for now (sanitization is a
        TODO) and the spec is simply recorded on the API. Without a spec, two
        checks apply, in order:

        1. ``stream=True`` with ``max_batch_size > 1`` requires ``predict``,
           ``encode_response`` and (if overridden) ``unbatch`` to be generators.
        2. ``stream=True`` alone requires ``predict`` and ``encode_response``
           to be generators.

        Raises:
            ValueError: If a required method is not a generator function.
        """
        # we will sanitize regularly if no spec;
        # if we have a spec then either:
        #   case 1: the spec implements a streaming API
        #   case 2: the spec implements a non-streaming API
        if spec:
            # TODO: Implement sanitization
            lit_api._spec = spec
            return

        # The inherited default unbatch is always acceptable; only a custom
        # override must itself be a generator in the batched-streaming case.
        uses_default_unbatch = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__

        if lit_api.stream and lit_api.max_batch_size > 1:
            batched_stream_ok = (
                inspect.isgeneratorfunction(lit_api.predict)
                and inspect.isgeneratorfunction(lit_api.encode_response)
                and (uses_default_unbatch or inspect.isgeneratorfunction(lit_api.unbatch))
            )
            if not batched_stream_ok:
                raise ValueError(
                    """When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and
                `lit_api.unbatch` must generate values using `yield`.

             Example:

                def predict(self, inputs):
                    ...
                    for i in range(max_token_length):
                        yield prediction

                def encode_response(self, outputs):
                    for output in outputs:
                        encoded_output = ...
                        yield encoded_output

                def unbatch(self, outputs):
                    for output in outputs:
                        unbatched_output = ...
                        yield unbatched_output
             """
                )

        stream_ok = inspect.isgeneratorfunction(lit_api.predict) and inspect.isgeneratorfunction(
            lit_api.encode_response
        )
        if lit_api.stream and not stream_ok:
            raise ValueError(
                """When `stream=True` both `lit_api.predict` and
             `lit_api.encode_response` must generate values using `yield`.

             Example:

                def predict(self, inputs):
                    ...
                    for i in range(max_token_length):
                        yield prediction

                def encode_response(self, outputs):
                    for output in outputs:
                        encoded_output = ...
                        yield encoded_output
             """
            )
593+
594+
595+ class SingleLoop (DefaultLoop ):
491596 def __call__ (
492597 self ,
493598 lit_api : LitAPI ,
@@ -505,7 +610,7 @@ def __call__(
505610 run_single_loop (lit_api , lit_spec , request_queue , response_queues , callback_runner )
506611
507612
508- class BatchedLoop (_BaseLoop ):
613+ class BatchedLoop (DefaultLoop ):
509614 def __call__ (
510615 self ,
511616 lit_api : LitAPI ,
@@ -531,7 +636,7 @@ def __call__(
531636 )
532637
533638
534- class StreamingLoop (_BaseLoop ):
639+ class StreamingLoop (DefaultLoop ):
535640 def __call__ (
536641 self ,
537642 lit_api : LitAPI ,
@@ -549,7 +654,7 @@ def __call__(
549654 run_streaming_loop (lit_api , lit_spec , request_queue , response_queues , callback_runner )
550655
551656
552- class BatchedStreamingLoop (_BaseLoop ):
657+ class BatchedStreamingLoop (DefaultLoop ):
553658 def __call__ (
554659 self ,
555660 lit_api : LitAPI ,
@@ -593,41 +698,6 @@ class Output:
593698 status : LitAPIStatus
594699
595700
596- class LitLoop (_BaseLoop ):
597- def __init__ (self ):
598- self ._context = {}
599-
600- def get_batch_requests (self , lit_api : LitAPI , request_queue : Queue , max_batch_size : int , batch_timeout : float ):
601- if max_batch_size <= 1 :
602- raise ValueError ("max_batch_size must be greater than 1" )
603-
604- batches , timed_out_uids = collate_requests (
605- lit_api ,
606- request_queue ,
607- max_batch_size ,
608- batch_timeout ,
609- )
610- return batches , timed_out_uids
611-
612- def get_request (self , request_queue : Queue , timeout : float = 1.0 ):
613- response_queue_id , uid , timestamp , x_enc = request_queue .get (timeout = timeout )
614- return response_queue_id , uid , timestamp , x_enc
615-
616- def populate_context (self , lit_spec : LitSpec , request : Any ):
617- if lit_spec and hasattr (lit_spec , "populate_context" ):
618- lit_spec .populate_context (self ._context , request )
619-
620- def put_response (
621- self , response_queues : List [Queue ], response_queue_id : int , uid : str , response_data : Any , status : LitAPIStatus
622- ) -> None :
623- response_queues [response_queue_id ].put ((uid , (response_data , status )))
624-
625- def put_error_response (
626- self , response_queues : List [Queue ], response_queue_id : int , uid : str , error : Exception
627- ) -> None :
628- response_queues [response_queue_id ].put ((uid , (error , LitAPIStatus .ERROR )))
629-
630-
631701class ContinuousBatchingLoop (LitLoop ):
632702 def __init__ (self , max_sequence_length : int = 2048 ):
633703 super ().__init__ ()
@@ -840,15 +910,7 @@ def inference_worker(
840910 logging .info (f"LitServe will use { lit_spec .__class__ .__name__ } spec" )
841911
842912 if loop == "auto" :
843- loop = (
844- BatchedStreamingLoop ()
845- if stream and max_batch_size > 1
846- else StreamingLoop ()
847- if stream
848- else BatchedLoop ()
849- if max_batch_size > 1
850- else SingleLoop ()
851- )
913+ loop = get_default_loop (stream , max_batch_size )
852914
853915 loop (
854916 lit_api ,
@@ -863,3 +925,15 @@ def inference_worker(
863925 workers_setup_status ,
864926 callback_runner ,
865927 )
928+
929+
def get_default_loop(stream: bool, max_batch_size: int) -> _BaseLoop:
    """Select the built-in loop matching the streaming/batching configuration.

    Args:
        stream: Whether the API streams its responses.
        max_batch_size: Maximum batch size; batching loops are used when > 1.

    Returns:
        A freshly constructed loop instance appropriate for the configuration.
    """
    if stream:
        # Streaming splits further on whether requests are batched.
        return BatchedStreamingLoop() if max_batch_size > 1 else StreamingLoop()
    if max_batch_size > 1:
        return BatchedLoop()
    return SingleLoop()
0 commit comments