@@ -40,7 +40,7 @@ def setup(self, device: str) -> None:
4040
4141 fabric = L .Fabric (
4242 accelerator = device .type ,
43- devices = 1 if device .type == "cpu" else [device .index ], # TODO: Update once LitServe supports "auto"
43+ devices = 1 if device .type == "cpu" else [device .index ],
4444 precision = precision ,
4545 )
4646 checkpoint_path = self .checkpoint_dir / "lit_model.pth"
@@ -99,7 +99,7 @@ def run_server(
9999 top_k : int = 200 ,
100100 max_new_tokens : int = 50 ,
101101 devices : int = 1 ,
102- accelerator : str = "cuda " ,
102+ accelerator : str = "auto " ,
103103 port : int = 8000
104104) -> None :
105105 """Serve a LitGPT model using LitServe
@@ -114,7 +114,8 @@ def run_server(
114114 generated text but can also lead to more incoherent texts.
115115 max_new_tokens: The number of generation steps to take.
116116 devices: How many devices/GPUs to use.
117- accelerator: The type of accelerator to use. For example, "cuda" or "cpu".
117+ accelerator: The type of accelerator to use. For example, "auto", "cuda", "cpu", or "mps".
118+ The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU.
118119 port: The network port number on which the model is configured to be served.
119120 """
120121 check_valid_checkpoint_dir (checkpoint_dir , model_filename = "lit_model.pth" )
0 commit comments