@@ -33,6 +33,7 @@ def __init__(
3333 max_new_tokens : int = 50 ,
3434 devices : int = 1 ,
3535 api_path : Optional [str ] = None ,
36+ generate_strategy : Optional [Literal ["sequential" , "tensor_parallel" ]] = None ,
3637 ) -> None :
3738 if not _LITSERVE_AVAILABLE :
3839 raise ImportError (str (_LITSERVE_AVAILABLE ))
@@ -47,6 +48,7 @@ def __init__(
4748 self .max_new_tokens = max_new_tokens
4849 self .top_p = top_p
4950 self .devices = devices
51+ self .generate_strategy = generate_strategy
5052
5153 def setup (self , device : str ) -> None :
5254 if ":" in device :
@@ -64,7 +66,8 @@ def setup(self, device: str) -> None:
6466 accelerator = accelerator ,
6567 quantize = self .quantize ,
6668 precision = self .precision ,
67- generate_strategy = ("sequential" if self .devices is not None and self .devices > 1 else None ),
69+ generate_strategy = self .generate_strategy
70+ or ("sequential" if self .devices is not None and self .devices > 1 else None ),
6871 )
6972 print ("Model successfully initialized." , file = sys .stderr )
7073
@@ -85,6 +88,7 @@ def __init__(
8588 max_new_tokens : int = 50 ,
8689 devices : int = 1 ,
8790 api_path : Optional [str ] = None ,
91+ generate_strategy : Optional [str ] = None ,
8892 ):
8993 super ().__init__ (
9094 checkpoint_dir ,
@@ -96,6 +100,7 @@ def __init__(
96100 max_new_tokens ,
97101 devices ,
98102 api_path = api_path ,
103+ generate_strategy = generate_strategy ,
99104 )
100105
101106 def setup (self , device : str ):
@@ -128,6 +133,7 @@ def __init__(
128133 max_new_tokens : int = 50 ,
129134 devices : int = 1 ,
130135 api_path : Optional [str ] = None ,
136+ generate_strategy : Optional [str ] = None ,
131137 ):
132138 super ().__init__ (
133139 checkpoint_dir ,
@@ -139,6 +145,7 @@ def __init__(
139145 max_new_tokens ,
140146 devices ,
141147 api_path = api_path ,
148+ generate_strategy = generate_strategy ,
142149 )
143150
144151 def setup (self , device : str ):
@@ -171,6 +178,7 @@ def __init__(
171178 max_new_tokens : int = 50 ,
172179 devices : int = 1 ,
173180 api_path : Optional [str ] = None ,
181+ generate_strategy : Optional [str ] = None ,
174182 ):
175183 super ().__init__ (
176184 checkpoint_dir ,
@@ -182,6 +190,7 @@ def __init__(
182190 max_new_tokens ,
183191 devices ,
184192 api_path = api_path ,
193+ generate_strategy = generate_strategy ,
185194 )
186195
187196 def setup (self , device : str ):
@@ -241,6 +250,7 @@ def run_server(
241250 access_token : Optional [str ] = None ,
242251 api_path : Optional [str ] = "/predict" ,
243252 timeout : int = 30 ,
253+ generate_strategy : Optional [Literal ["sequential" , "tensor_parallel" ]] = None ,
244254) -> None :
245255 """Serve a LitGPT model using LitServe.
246256
@@ -284,6 +294,10 @@ def run_server(
284294 access_token: Optional API token to access models with restrictions.
285295 api_path: The custom API path for the endpoint (e.g., "/my_api/classify").
286296 timeout: Request timeout in seconds. Defaults to 30.
297+ generate_strategy: The generation strategy to use. The "sequential" strategy (default for devices > 1)
298+ allows running models that wouldn't fit in a single card by partitioning the transformer blocks across
299+ all devices and running them sequentially. "tensor_parallel" shards the model using tensor parallelism.
300+ If None (default for devices = 1), the model is not distributed.
287301 """
288302 checkpoint_dir = auto_download_checkpoint (model_name = checkpoint_dir , access_token = access_token )
289303 pprint (locals ())
@@ -301,6 +315,7 @@ def run_server(
301315 max_new_tokens = max_new_tokens ,
302316 devices = devices ,
303317 api_path = api_path ,
318+ generate_strategy = generate_strategy ,
304319 ),
305320 spec = OpenAISpec () if openai_spec else None ,
306321 accelerator = accelerator ,