
Commit 2f504d8

feat(serve.py): add api_path parameter to cli options to allow custom API endpoint configuration (#2080)
1 parent 8f19053 commit 2f504d8

File tree

1 file changed: +55 −9 lines changed

1 file changed

+55
-9
lines changed

litgpt/deploy/serve.py

Lines changed: 55 additions & 9 deletions
@@ -31,11 +31,13 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ) -> None:
         if not _LITSERVE_AVAILABLE:
             raise ImportError(str(_LITSERVE_AVAILABLE))
 
-        super().__init__()
+        super().__init__(api_path=api_path)
+
         self.checkpoint_dir = checkpoint_dir
         self.quantize = quantize
         self.precision = precision
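
For context, here is a minimal standalone sketch of the pattern this hunk introduces: litserve's LitAPI accepts api_path, and forwarding it through super().__init__() mounts the endpoint at that route instead of the default /predict. The EchoAPI class and port below are illustrative, not part of this commit.

import litserve as ls

class EchoAPI(ls.LitAPI):
    def __init__(self, api_path=None):
        # The same call this commit adds: routing is handled by the LitAPI base class.
        super().__init__(api_path=api_path)

    def setup(self, device):
        pass  # nothing to load in this toy example

    def decode_request(self, request):
        return request["prompt"]

    def predict(self, x):
        return x

    def encode_response(self, output):
        return {"output": output}

if __name__ == "__main__":
    # Served at POST /my_api/classify rather than the default /predict.
    ls.LitServer(EchoAPI(api_path="/my_api/classify")).run(port=8000)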
@@ -61,12 +63,11 @@ def setup(self, device: str) -> None:
             accelerator=accelerator,
             quantize=self.quantize,
             precision=self.precision,
-            generate_strategy="sequential" if self.devices is not None and self.devices > 1 else None,
+            generate_strategy=("sequential" if self.devices is not None and self.devices > 1 else None),
         )
         print("Model successfully initialized.", file=sys.stderr)
 
     def decode_request(self, request: Dict[str, Any]) -> Any:
-        # Convert the request payload to your model input.
         prompt = str(request["prompt"])
         return prompt

@@ -82,15 +83,30 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )
 
     def setup(self, device: str):
         super().setup(device)
 
     def predict(self, inputs: str) -> Any:
         output = self.llm.generate(
-            inputs, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, max_new_tokens=self.max_new_tokens
+            inputs,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            max_new_tokens=self.max_new_tokens,
         )
         return output

@@ -110,14 +126,24 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )
 
     def setup(self, device: str):
         super().setup(device)
 
     def predict(self, inputs: torch.Tensor) -> Any:
-        # Run the model on the input and return the output.
         yield from self.llm.generate(
             inputs,
             temperature=self.temperature,
@@ -143,8 +169,19 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )
 
     def setup(self, device: str):
         super().setup(device)
@@ -178,7 +215,12 @@ def predict(self, inputs: str, context: dict) -> Any:
 
         # Run the model on the input and return the output.
         yield from self.llm.generate(
-            inputs, temperature=temperature, top_k=self.top_k, top_p=top_p, max_new_tokens=max_new_tokens, stream=True
+            inputs,
+            temperature=temperature,
+            top_k=self.top_k,
+            top_p=top_p,
+            max_new_tokens=max_new_tokens,
+            stream=True,
         )
 
@@ -196,6 +238,7 @@ def run_server(
     stream: bool = False,
     openai_spec: bool = False,
     access_token: Optional[str] = None,
+    api_path: Optional[str] = "/predict",
 ) -> None:
     """Serve a LitGPT model using LitServe.
 
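
With the new keyword in place, a hedged usage sketch (the model name is illustrative; any checkpoint that auto_download_checkpoint can resolve works, and litgpt serve should expose the same option as --api_path, since the CLI is generated from this signature):

from litgpt.deploy.serve import run_server

# Serve the model at POST /my_api/classify instead of the default /predict.
run_server("EleutherAI/pythia-160m", api_path="/my_api/classify")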
@@ -237,11 +280,13 @@
             `/v1/chat/completions` endpoints that work with the OpenAI SDK and other OpenAI-compatible clients,
             making it easy to integrate with existing applications that use the OpenAI API.
         access_token: Optional API token to access models with restrictions.
+        api_path: The custom API path for the endpoint (e.g., "/my_api/classify").
     """
     checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
     pprint(locals())
 
     api_class = OpenAISpecLitAPI if openai_spec else StreamLitAPI if stream else SimpleLitAPI
+
     server = LitServer(
         api_class(
             checkpoint_dir=checkpoint_dir,
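
On the client side, requests then target the configured route; a small sketch assuming the server above runs locally on the default port, with the "prompt" key matching what decode_request expects:

import requests

# The path must match the api_path the server was started with.
response = requests.post(
    "http://127.0.0.1:8000/my_api/classify",
    json={"prompt": "What do llamas eat?"},
)
print(response.json())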
@@ -252,6 +297,7 @@
             top_p=top_p,
             max_new_tokens=max_new_tokens,
             devices=devices,
+            api_path=api_path,
         ),
         spec=OpenAISpec() if openai_spec else None,
         accelerator=accelerator,
