
Commit 54fe4a8

Feature: Adds support for OpenAISpec in litgpt serve (#1943)
1 parent 552ac10 commit 54fe4a8

File tree (3 files changed, +197 -32 lines):

  litgpt/deploy/serve.py
  pyproject.toml
  tests/test_serve.py

litgpt/deploy/serve.py

Lines changed: 71 additions & 32 deletions
@@ -1,4 +1,5 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+import json
 import sys
 from pathlib import Path
 from pprint import pprint
@@ -11,8 +12,10 @@
 from litgpt.utils import auto_download_checkpoint
 
 _LITSERVE_AVAILABLE = RequirementCache("litserve")
+_JINJA2_AVAILABLE = RequirementCache("jinja2")
 if _LITSERVE_AVAILABLE:
     from litserve import LitAPI, LitServer
+    from litserve.specs.openai import ChatCompletionRequest, OpenAISpec
 else:
     LitAPI, LitServer = object, object
 
@@ -129,6 +132,55 @@ def encode_response(self, output):
             yield {"output": out}
 
 
+class OpenAISpecLitAPI(BaseLitAPI):
+    def __init__(
+        self,
+        checkpoint_dir: Path,
+        quantize: Optional[str] = None,
+        precision: Optional[str] = None,
+        temperature: float = 0.8,
+        top_k: int = 50,
+        top_p: float = 1.0,
+        max_new_tokens: int = 50,
+        devices: int = 1,
+    ):
+        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+
+    def setup(self, device: str):
+        super().setup(device)
+        if not _JINJA2_AVAILABLE:
+            raise ImportError(str(_JINJA2_AVAILABLE))
+        from jinja2 import Template
+
+        config_path = self.checkpoint_dir / "tokenizer_config.json"
+        if not config_path.is_file():
+            raise FileNotFoundError(f"Tokenizer config file not found at {config_path}")
+
+        with open(config_path, encoding="utf-8") as fp:
+            config = json.load(fp)
+        chat_template = config.get("chat_template", None)
+        if chat_template is None:
+            raise ValueError("chat_template not found in tokenizer config file.")
+        self.chat_template = chat_template
+
+        self.template = Template(self.chat_template)
+
+    def decode_request(self, request: "ChatCompletionRequest") -> Any:
+        # Apply chat template to request messages
+        return self.template.render(messages=request.messages)
+
+    def predict(self, inputs: str, context: dict) -> Any:
+        # Extract parameters from context with fallback to instance attributes
+        temperature = context.get("temperature") or self.temperature
+        top_p = context.get("top_p", self.top_p) or self.top_p
+        max_new_tokens = context.get("max_completion_tokens") or self.max_new_tokens
+
+        # Run the model on the input and return the output.
+        yield from self.llm.generate(
+            inputs, temperature=temperature, top_k=self.top_k, top_p=top_p, max_new_tokens=max_new_tokens, stream=True
+        )
+
+
 def run_server(
     checkpoint_dir: Path,
     quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
@@ -141,6 +193,7 @@ def run_server(
     accelerator: str = "auto",
     port: int = 8000,
     stream: bool = False,
+    openai_spec: bool = False,
     access_token: Optional[str] = None,
 ) -> None:
     """Serve a LitGPT model using LitServe.
@@ -179,42 +232,28 @@ def run_server(
             The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU.
         port: The network port number on which the model is configured to be served.
         stream: Whether to stream the responses.
+        openai_spec: Whether to use the OpenAISpec.
         access_token: Optional API token to access models with restrictions.
     """
     checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
     pprint(locals())
 
-    if not stream:
-        server = LitServer(
-            SimpleLitAPI(
-                checkpoint_dir=checkpoint_dir,
-                quantize=quantize,
-                precision=precision,
-                temperature=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                max_new_tokens=max_new_tokens,
-                devices=devices,
-            ),
-            accelerator=accelerator,
-            devices=1,  # We need to use the devives inside the `SimpleLitAPI` class
-        )
-
-    else:
-        server = LitServer(
-            StreamLitAPI(
-                checkpoint_dir=checkpoint_dir,
-                quantize=quantize,
-                precision=precision,
-                temperature=temperature,
-                top_k=top_k,
-                top_p=top_p,
-                max_new_tokens=max_new_tokens,
-                devices=devices,  # We need to use the devives inside the `StreamLitAPI` class
-            ),
-            accelerator=accelerator,
-            devices=1,
-            stream=True,
-        )
+    api_class = OpenAISpecLitAPI if openai_spec else StreamLitAPI if stream else SimpleLitAPI
+    server = LitServer(
+        api_class(
+            checkpoint_dir=checkpoint_dir,
+            quantize=quantize,
+            precision=precision,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            max_new_tokens=max_new_tokens,
+            devices=devices,
+        ),
+        spec=OpenAISpec() if openai_spec else None,
+        accelerator=accelerator,
+        devices=1,
+        stream=stream,
+    )
 
     server.run(port=port, generate_client_file=False)
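
Note: the new OpenAISpecLitAPI builds the prompt by rendering the tokenizer's chat_template with jinja2 inside decode_request(), then streams tokens from self.llm.generate() with temperature, top_p, and max_completion_tokens taken from the request context (falling back to the values given at server start). A minimal standalone sketch of the rendering step follows; the template string here is a hypothetical stand-in for the "chat_template" entry in tokenizer_config.json, not the template of any particular checkpoint.

# Standalone sketch of the chat-template rendering done in decode_request().
# Assumes jinja2 is installed; the template string is a made-up example.
from jinja2 import Template

chat_template = (
    "{% for message in messages %}"
    "<|{{ message['role'] }}|>\n{{ message['content'] }}\n"
    "{% endfor %}"
    "<|assistant|>\n"
)

template = Template(chat_template)
prompt = template.render(messages=[{"role": "user", "content": "Hello!"}])
print(prompt)
# <|user|>
# Hello!
# <|assistant|>

From the CLI, the new path is enabled with "litgpt serve <checkpoint_dir> --openai_spec true", as exercised by the tests added below.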

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ optional-dependencies.extra = [
 optional-dependencies.test = [
   "einops>=0.7",
   "protobuf>=4.23.4",
+  "pydantic>=2.11",
   "pytest>=8.1.1",
   "pytest-benchmark>=5.1",
   "pytest-dependency>=0.6",

tests/test_serve.py

Lines changed: 125 additions & 0 deletions
@@ -1,4 +1,5 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+import json
 import platform
 import shutil
 import subprocess
@@ -136,3 +137,127 @@ def run_server():
         if process:
             kill_process_tree(process.pid)
         server_thread.join()
+
+
+@_RunIf(min_cuda_gpus=1)
+def test_serve_with_openai_spec_missing_chat_template(tmp_path):
+    seed_everything(123)
+    ours_config = Config.from_name("pythia-14m")
+    download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
+    ours_model = GPT(ours_config)
+    checkpoint_path = tmp_path / "lit_model.pth"
+    torch.save(ours_model.state_dict(), checkpoint_path)
+    config_path = tmp_path / "model_config.yaml"
+    with open(config_path, "w", encoding="utf-8") as fp:
+        yaml.dump(asdict(ours_config), fp)
+
+    run_command = ["litgpt", "serve", tmp_path, "--openai_spec", "true"]
+
+    process = None
+
+    def run_server():
+        nonlocal process
+        try:
+            process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+        except subprocess.TimeoutExpired:
+            print("Server start-up timeout expired")
+            return None, None
+
+    server_thread = threading.Thread(target=run_server)
+    server_thread.start()
+
+    time.sleep(30)  # Give the server some time to start and raise the error
+
+    try:
+        stdout = process.stdout.read().strip() if process.stdout else ""
+        stderr = process.stderr.read().strip() if process.stderr else ""
+        output = (stdout or "") + (stderr or "")
+        assert "ValueError: chat_template not found in tokenizer config file." in output, (
+            "Expected ValueError for missing chat_template not found."
+        )
+    finally:
+        if process:
+            kill_process_tree(process.pid)
+        server_thread.join()
+
+
+@_RunIf(min_cuda_gpus=1)
+def test_serve_with_openai_spec(tmp_path):
+    seed_everything(123)
+    ours_config = Config.from_name("SmolLM2-135M-Instruct")
+    download_from_hub(repo_id="HuggingFaceTB/SmolLM2-135M-Instruct", tokenizer_only=True, checkpoint_dir=tmp_path)
+    shutil.move(str(tmp_path / "HuggingFaceTB" / "SmolLM2-135M-Instruct" / "tokenizer.json"), str(tmp_path))
+    shutil.move(str(tmp_path / "HuggingFaceTB" / "SmolLM2-135M-Instruct" / "tokenizer_config.json"), str(tmp_path))
+    ours_model = GPT(ours_config)
+    checkpoint_path = tmp_path / "lit_model.pth"
+    torch.save(ours_model.state_dict(), checkpoint_path)
+    config_path = tmp_path / "model_config.yaml"
+    with open(config_path, "w", encoding="utf-8") as fp:
+        yaml.dump(asdict(ours_config), fp)
+
+    run_command = ["litgpt", "serve", tmp_path, "--openai_spec", "true"]
+
+    process = None
+
+    def run_server():
+        nonlocal process
+        try:
+            process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+        except subprocess.TimeoutExpired:
+            print("Server start-up timeout expired")
+
+    server_thread = threading.Thread(target=run_server)
+    server_thread.start()
+
+    _wait_and_check_response()
+
+    try:
+        # Test server health
+        response = requests.get("http://127.0.0.1:8000/health")
+        assert response.status_code == 200, f"Server health check failed with status code {response.status_code}"
+        assert response.text == "ok", "Server did not respond as expected."
+
+        # Test non-streaming chat completion
+        response = requests.post(
+            "http://127.0.0.1:8000/v1/chat/completions",
+            json={
+                "model": "SmolLM2-135M-Instruct",
+                "messages": [{"role": "user", "content": "Hello!"}],
+            },
+        )
+        assert response.status_code == 200, (
+            f"Non-streaming chat completion failed with status code {response.status_code}"
+        )
+        response_json = response.json()
+        assert "choices" in response_json, "Response JSON does not contain 'choices'."
+        assert "message" in response_json["choices"][0], "Response JSON does not contain 'message' in 'choices'."
+        assert "content" in response_json["choices"][0]["message"], (
+            "Response JSON does not contain 'content' in 'message'."
+        )
+        assert response_json["choices"][0]["message"]["content"], "Content is empty in the response."
+
+        # Test streaming chat completion
+        stream_response = requests.post(
+            "http://127.0.0.1:8000/v1/chat/completions",
+            json={
+                "model": "SmolLM2-135M-Instruct",
+                "messages": [{"role": "user", "content": "Hello!"}],
+                "stream": True,
+            },
+        )
+        assert stream_response.status_code == 200, (
+            f"Streaming chat completion failed with status code {stream_response.status_code}"
+        )
+        for line in stream_response.iter_lines():
+            decoded = line.decode("utf-8").replace("data: ", "").replace("[DONE]", "").strip()
+            if decoded:
+                data = json.loads(decoded)
+                assert "choices" in data, "Response JSON does not contain 'choices'."
+                assert "delta" in data["choices"][0], "Response JSON does not contain 'delta' in 'choices'."
+                assert "content" in data["choices"][0]["delta"], "Response JSON does not contain 'content' in 'delta'."
+    finally:
+        if process:
+            kill_process_tree(process.pid)
+        server_thread.join()
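
Note: the tests above exercise the OpenAI-compatible /v1/chat/completions route with plain requests calls. Because the route follows the OpenAI chat-completions schema, a standard OpenAI client should also be able to talk to the server; below is a minimal sketch assuming the server was started with --openai_spec true on the default port 8000 (the base_url, the placeholder api_key, and the use of the openai package are illustrative assumptions, not part of this commit).

# Sketch: querying the litgpt server through the official `openai` client.
# Assumes a local server started with --openai_spec true on port 8000;
# base_url and the placeholder api_key are assumptions for illustration.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="not-needed")

completion = client.chat.completions.create(
    model="SmolLM2-135M-Instruct",  # model name as used in the test above
    messages=[{"role": "user", "content": "Hello!"}],
)
print(completion.choices[0].message.content)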
