 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+import json
 import platform
 import shutil
 import subprocess
@@ -136,3 +137,127 @@ def run_server():
         if process:
             kill_process_tree(process.pid)
         server_thread.join()
+
+
+@_RunIf(min_cuda_gpus=1)
+def test_serve_with_openai_spec_missing_chat_template(tmp_path):
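+    """The pythia-14m tokenizer_config.json ships without a chat_template, so serving
+    the checkpoint with --openai_spec must abort with a ValueError."""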
+    seed_everything(123)
+    ours_config = Config.from_name("pythia-14m")
+    download_from_hub(repo_id="EleutherAI/pythia-14m", tokenizer_only=True, checkpoint_dir=tmp_path)
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer.json"), str(tmp_path))
+    shutil.move(str(tmp_path / "EleutherAI" / "pythia-14m" / "tokenizer_config.json"), str(tmp_path))
+    ours_model = GPT(ours_config)
+    checkpoint_path = tmp_path / "lit_model.pth"
+    torch.save(ours_model.state_dict(), checkpoint_path)
+    config_path = tmp_path / "model_config.yaml"
+    with open(config_path, "w", encoding="utf-8") as fp:
+        yaml.dump(asdict(ours_config), fp)
+
+    run_command = ["litgpt", "serve", tmp_path, "--openai_spec", "true"]
+
+    process = None
+
+    def run_server():
+        nonlocal process
+        try:
+            process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+        except OSError as err:  # Popen raises OSError on launch failure; it never raises TimeoutExpired
+            print(f"Failed to start server: {err}")
+
+    server_thread = threading.Thread(target=run_server)
+    server_thread.start()
+
+    time.sleep(30)  # Give the server some time to start and raise the error
+
+    try:
+        stdout = process.stdout.read().strip() if process and process.stdout else ""
+        stderr = process.stderr.read().strip() if process and process.stderr else ""
+        output = stdout + stderr
+        assert "ValueError: chat_template not found in tokenizer config file." in output, (
+            "Expected a ValueError for the missing chat_template."
+        )
+    finally:
+        if process:
+            kill_process_tree(process.pid)
+        server_thread.join()
+
+
+@_RunIf(min_cuda_gpus=1)
+def test_serve_with_openai_spec(tmp_path):
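+    """End-to-end check of `litgpt serve --openai_spec true`: verifies /health plus the
+    non-streaming and streaming /v1/chat/completions endpoints."""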
+    seed_everything(123)
+    ours_config = Config.from_name("SmolLM2-135M-Instruct")
+    download_from_hub(repo_id="HuggingFaceTB/SmolLM2-135M-Instruct", tokenizer_only=True, checkpoint_dir=tmp_path)
+    shutil.move(str(tmp_path / "HuggingFaceTB" / "SmolLM2-135M-Instruct" / "tokenizer.json"), str(tmp_path))
+    shutil.move(str(tmp_path / "HuggingFaceTB" / "SmolLM2-135M-Instruct" / "tokenizer_config.json"), str(tmp_path))
+    ours_model = GPT(ours_config)
+    checkpoint_path = tmp_path / "lit_model.pth"
+    torch.save(ours_model.state_dict(), checkpoint_path)
+    config_path = tmp_path / "model_config.yaml"
+    with open(config_path, "w", encoding="utf-8") as fp:
+        yaml.dump(asdict(ours_config), fp)
+
+    run_command = ["litgpt", "serve", tmp_path, "--openai_spec", "true"]
+
+    process = None
+
+    def run_server():
+        nonlocal process
+        try:
+            process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+        except OSError as err:  # Popen raises OSError on launch failure; it never raises TimeoutExpired
+            print(f"Failed to start server: {err}")
+
+    server_thread = threading.Thread(target=run_server)
+    server_thread.start()
+
+    _wait_and_check_response()  # shared helper: blocks until the server comes up
+
+    try:
+        # Test server health
+        response = requests.get("http://127.0.0.1:8000/health")
+        assert response.status_code == 200, f"Server health check failed with status code {response.status_code}"
+        assert response.text == "ok", "Server did not respond as expected."
+
+        # Test non-streaming chat completion
+        response = requests.post(
+            "http://127.0.0.1:8000/v1/chat/completions",
+            json={
+                "model": "SmolLM2-135M-Instruct",
+                "messages": [{"role": "user", "content": "Hello!"}],
+            },
+        )
+        assert response.status_code == 200, (
+            f"Non-streaming chat completion failed with status code {response.status_code}"
+        )
+        response_json = response.json()
+        assert "choices" in response_json, "Response JSON does not contain 'choices'."
+        assert "message" in response_json["choices"][0], "Response JSON does not contain 'message' in 'choices'."
+        assert "content" in response_json["choices"][0]["message"], (
+            "Response JSON does not contain 'content' in 'message'."
+        )
+        assert response_json["choices"][0]["message"]["content"], "Content is empty in the response."
+
+        # Test streaming chat completion
+        stream_response = requests.post(
+            "http://127.0.0.1:8000/v1/chat/completions",
+            json={
+                "model": "SmolLM2-135M-Instruct",
+                "messages": [{"role": "user", "content": "Hello!"}],
+                "stream": True,
+            },
+        )
+        assert stream_response.status_code == 200, (
+            f"Streaming chat completion failed with status code {stream_response.status_code}"
+        )
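+        # Per the OpenAI streaming format, each server-sent event arrives as
+        # `data: {json chunk}` and the stream ends with `data: [DONE]`; both
+        # markers are stripped below before the JSON is parsed.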
+        for line in stream_response.iter_lines():
+            decoded = line.decode("utf-8").replace("data: ", "").replace("[DONE]", "").strip()
+            if decoded:
+                data = json.loads(decoded)
+                assert "choices" in data, "Response JSON does not contain 'choices'."
+                assert "delta" in data["choices"][0], "Response JSON does not contain 'delta' in 'choices'."
+                assert "content" in data["choices"][0]["delta"], "Response JSON does not contain 'content' in 'delta'."
+    finally:
+        if process:
+            kill_process_tree(process.pid)
+        server_thread.join()
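
For reference, once "litgpt serve <checkpoint_dir> --openai_spec true" is running, the same endpoints these tests exercise can be driven with the official openai client instead of raw requests. This is a minimal sketch, not part of the diff: it assumes the openai package is installed and that the server listens on the default 127.0.0.1:8000 used above (the API key is an ignored placeholder).

    from openai import OpenAI

    # Point the client at the locally served model; the key is a dummy value.
    client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="not-needed")

    # Non-streaming request, mirroring the first POST in the test.
    reply = client.chat.completions.create(
        model="SmolLM2-135M-Instruct",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(reply.choices[0].message.content)

    # Streaming request: each chunk carries a delta rather than a full message.
    for chunk in client.chat.completions.create(
        model="SmolLM2-135M-Instruct",
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
    ):
        print(chunk.choices[0].delta.content or "", end="")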