Skip to content

Commit dbddd23

Browse files
adi776bora, pre-commit-ci[bot], and bhimrazy
authored
feat: add generate_strategy option to litgpt serve (#2188)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bhimraj Yadav <bhimrajyadav977@gmail.com>
1 parent ac816e7 commit dbddd23

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

litgpt/deploy/serve.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
max_new_tokens: int = 50,
3434
devices: int = 1,
3535
api_path: Optional[str] = None,
36+
generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None,
3637
) -> None:
3738
if not _LITSERVE_AVAILABLE:
3839
raise ImportError(str(_LITSERVE_AVAILABLE))
@@ -47,6 +48,7 @@ def __init__(
4748
self.max_new_tokens = max_new_tokens
4849
self.top_p = top_p
4950
self.devices = devices
51+
self.generate_strategy = generate_strategy
5052

5153
def setup(self, device: str) -> None:
5254
if ":" in device:
@@ -64,7 +66,8 @@ def setup(self, device: str) -> None:
6466
accelerator=accelerator,
6567
quantize=self.quantize,
6668
precision=self.precision,
67-
generate_strategy=("sequential" if self.devices is not None and self.devices > 1 else None),
69+
generate_strategy=self.generate_strategy
70+
or ("sequential" if self.devices is not None and self.devices > 1 else None),
6871
)
6972
print("Model successfully initialized.", file=sys.stderr)
7073

@@ -85,6 +88,7 @@ def __init__(
8588
max_new_tokens: int = 50,
8689
devices: int = 1,
8790
api_path: Optional[str] = None,
91+
generate_strategy: Optional[str] = None,
8892
):
8993
super().__init__(
9094
checkpoint_dir,
@@ -96,6 +100,7 @@ def __init__(
96100
max_new_tokens,
97101
devices,
98102
api_path=api_path,
103+
generate_strategy=generate_strategy,
99104
)
100105

101106
def setup(self, device: str):
@@ -128,6 +133,7 @@ def __init__(
128133
max_new_tokens: int = 50,
129134
devices: int = 1,
130135
api_path: Optional[str] = None,
136+
generate_strategy: Optional[str] = None,
131137
):
132138
super().__init__(
133139
checkpoint_dir,
@@ -139,6 +145,7 @@ def __init__(
139145
max_new_tokens,
140146
devices,
141147
api_path=api_path,
148+
generate_strategy=generate_strategy,
142149
)
143150

144151
def setup(self, device: str):
@@ -171,6 +178,7 @@ def __init__(
171178
max_new_tokens: int = 50,
172179
devices: int = 1,
173180
api_path: Optional[str] = None,
181+
generate_strategy: Optional[str] = None,
174182
):
175183
super().__init__(
176184
checkpoint_dir,
@@ -182,6 +190,7 @@ def __init__(
182190
max_new_tokens,
183191
devices,
184192
api_path=api_path,
193+
generate_strategy=generate_strategy,
185194
)
186195

187196
def setup(self, device: str):
@@ -241,6 +250,7 @@ def run_server(
241250
access_token: Optional[str] = None,
242251
api_path: Optional[str] = "/predict",
243252
timeout: int = 30,
253+
generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None,
244254
) -> None:
245255
"""Serve a LitGPT model using LitServe.
246256
@@ -284,6 +294,10 @@ def run_server(
284294
access_token: Optional API token to access models with restrictions.
285295
api_path: The custom API path for the endpoint (e.g., "/my_api/classify").
286296
timeout: Request timeout in seconds. Defaults to 30.
297+
generate_strategy: The generation strategy to use. The "sequential" strategy (default for devices > 1)
298+
allows running models that wouldn't fit in a single card by partitioning the transformer blocks across
299+
all devices and running them sequentially. "tensor_parallel" shards the model using tensor parallelism.
300+
If None (default for devices = 1), the model is not distributed.
287301
"""
288302
checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
289303
pprint(locals())
@@ -301,6 +315,7 @@ def run_server(
301315
max_new_tokens=max_new_tokens,
302316
devices=devices,
303317
api_path=api_path,
318+
generate_strategy=generate_strategy,
304319
),
305320
spec=OpenAISpec() if openai_spec else None,
306321
accelerator=accelerator,

tests/test_serve.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,45 @@ def run_server():
254254
if process:
255255
kill_process_tree(process.pid)
256256
server_thread.join()
257+
258+
259+
@pytest.mark.parametrize(
    "generate_strategy",
    [
        pytest.param("sequential", marks=_RunIf(min_cuda_gpus=1)),
        pytest.param("tensor_parallel", marks=_RunIf(min_cuda_gpus=2)),
    ],
)
def test_serve_with_generate_strategy(tmp_path, generate_strategy):
    """Launch ``litgpt serve`` with an explicit ``--generate_strategy`` and verify the endpoint answers.

    A tiny pythia-14m checkpoint is assembled locally: only the tokenizer files are
    downloaded from the Hub; the weights come from a freshly initialized model.
    """
    seed_everything(123)
    config = Config.from_name("pythia-14m")
    download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
    # The tokenizer lands in a nested HF-style directory; move it next to the weights.
    for fname in ("tokenizer.json", "tokenizer_config.json"):
        shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / fname), str(tmp_path))
    model = GPT(config)
    torch.save(model.state_dict(), tmp_path / "lit_model.pth")
    with open(tmp_path / "model_config.yaml", "w", encoding="utf-8") as stream:
        yaml.dump(asdict(config), stream)

    # Test with generate strategy
    command = ["litgpt", "serve", tmp_path, "--generate_strategy", generate_strategy]

    process = None

    def launch():
        # NOTE(review): Popen does not raise TimeoutExpired itself; this mirrors the
        # guard used by the other serve tests in this file.
        nonlocal process
        try:
            process = subprocess.Popen(command, stdout=None, stderr=None, text=True)
        except subprocess.TimeoutExpired:
            print("Server start-up timeout expired")

    server_thread = threading.Thread(target=launch)
    server_thread.start()

    # Poll the served endpoint until it responds (helper defined elsewhere in this file).
    _wait_and_check_response()

    # Tear the server down regardless of whether the response check passed.
    if process:
        kill_process_tree(process.pid)
    server_thread.join()

0 commit comments

Comments (0)