Commit d46da46

Implement JIT and Auto-Unload

1 parent a459c01 commit d46da46

14 files changed (+1254, -238 lines)

README.md

Lines changed: 18 additions & 0 deletions
@@ -43,6 +43,7 @@ This repository hosts a high-performance API server that provides OpenAI-compati
 - 🎛️ **LoRA adapter support** for fine-tuned image generation
 - **Configurable quantization** (4-bit, 8-bit, 16-bit) for optimal performance
 - 🧠 **Customizable context length** for memory optimization and performance tuning
+- ♻️ **JIT loading with idle auto-unload** to reclaim VRAM when the server is idle

 ---

@@ -401,6 +402,23 @@ mlx-openai-server launch \

 ```

+#### Enabling JIT Loading & Auto-Unload
+
+Use the `--jit` flag to defer model initialization until the first request arrives. Pair it with
+`--auto-unload-minutes <minutes>` to automatically unload the model after a period of inactivity and
+reclaim VRAM. Example:
+
+```bash
+mlx-openai-server launch \
+    --model-path <path-to-mlx-model> \
+    --model-type lm \
+    --jit \
+    --auto-unload-minutes 30
+```
+
+When JIT mode is active, the `/health` endpoint reports `status="ok"` with
+`model_status="unloaded"` while the model is idle and loads it back on demand for the next request.
+
 #### Server Parameters
 - `--model-path`: Path to the MLX model directory (local path or Hugging Face model repository). Required for `lm`, `multimodal`, `embeddings`, `image-generation`, `image-edit`, and `whisper` model types.
 - `--model-type`: Type of model to run:
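For readers trying the new flags, the `/health` behaviour documented above can be checked directly. A minimal sketch, assuming the server listens on `localhost:8000` (adjust to your `--port`) and that `HealthCheckStatus.OK` serializes to `"ok"` as the README text states:

```python
# Illustrative check of the documented /health behaviour; host and port are assumptions.
import json
import urllib.request

with urllib.request.urlopen("http://localhost:8000/health") as resp:
    payload = json.load(resp)

# While the model is idle in JIT mode, this commit's HealthCheckResponse reports roughly:
#   {"status": "ok", "model_id": "<path-to-mlx-model>", "model_status": "unloaded"}
print(payload.get("status"), payload.get("model_status"))
```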

app/api/endpoints.py

Lines changed: 194 additions & 34 deletions
@@ -6,16 +6,18 @@
 import json
 import random
 import time
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, TypeAlias

 from fastapi import APIRouter, Form, HTTPException, Request
 from fastapi.responses import JSONResponse, StreamingResponse
 from loguru import logger
 import numpy as np

 from ..handler import MFLUX_AVAILABLE, MLXFluxHandler
+from ..handler.mlx_embeddings import MLXEmbeddingsHandler
 from ..handler.mlx_lm import MLXLMHandler
 from ..handler.mlx_vlm import MLXVLMHandler
+from ..handler.mlx_whisper import MLXWhisperHandler
 from ..schemas.openai import (
     ChatCompletionChunk,
     ChatCompletionMessageToolCall,
@@ -48,30 +50,114 @@
 router = APIRouter()


+MLXHandlerType: TypeAlias = (
+    MLXVLMHandler | MLXLMHandler | MLXFluxHandler | MLXEmbeddingsHandler | MLXWhisperHandler
+)
+
+
+async def _get_handler_or_error(
+    raw_request: Request, reason: str
+) -> tuple[MLXHandlerType | None, JSONResponse | None]:
+    """Return a loaded handler or an error response if unavailable."""
+
+    handler_manager = getattr(raw_request.app.state, "handler_manager", None)
+    if handler_manager is not None:
+        try:
+            handler = await handler_manager.ensure_loaded(reason)
+        except HTTPException:
+            raise
+        except Exception as e:  # pragma: no cover - defensive logging
+            logger.exception(
+                f"Unable to load handler via JIT for {reason}. {type(e).__name__}: {e}"
+            )
+            return (
+                None,
+                JSONResponse(
+                    content=create_error_response(
+                        "Failed to load model handler",
+                        "server_error",
+                        HTTPStatus.INTERNAL_SERVER_ERROR,
+                    ),
+                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+                ),
+            )
+
+        if handler is not None:
+            return handler, None
+
+    handler = getattr(raw_request.app.state, "handler", None)
+    if handler is None:
+        return (
+            None,
+            JSONResponse(
+                content=create_error_response(
+                    "Model handler not initialized",
+                    "service_unavailable",
+                    HTTPStatus.SERVICE_UNAVAILABLE,
+                ),
+                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+            ),
+        )
+    return handler, None
+
+
+def _get_cached_model_metadata(raw_request: Request) -> dict[str, Any] | None:
+    """Fetch cached model metadata from application state, if available."""
+
+    metadata_cache = getattr(raw_request.app.state, "model_metadata", None)
+    if isinstance(metadata_cache, list) and metadata_cache:
+        entry = metadata_cache[0]
+        if isinstance(entry, dict):
+            return entry
+    return None
+
+
+def _get_configured_model_id(raw_request: Request) -> str | None:
+    """Return the configured model identifier from config or cache."""
+
+    config = getattr(raw_request.app.state, "server_config", None)
+    if config is not None:
+        return getattr(config, "model_identifier", getattr(config, "model_path", None))
+
+    cached = _get_cached_model_metadata(raw_request)
+    if cached is not None:
+        return cached.get("id")
+    return None
+
+
 # =============================================================================
 # Critical/Monitoring Endpoints - Defined first to ensure priority matching
 # =============================================================================


 @router.get("/health", response_model=None)
 async def health(raw_request: Request) -> HealthCheckResponse | JSONResponse:
-    """
-    Health check endpoint - verifies handler initialization status.
+    """Health check endpoint aware of JIT/auto-unload state."""
+    handler_manager = getattr(raw_request.app.state, "handler_manager", None)
+    configured_model_id = _get_configured_model_id(raw_request)
+
+    if handler_manager is not None:
+        handler = getattr(handler_manager, "current_handler", None)
+        if handler is not None:
+            model_id = getattr(handler, "model_path", configured_model_id or "unknown")
+            return HealthCheckResponse(
+                status=HealthCheckStatus.OK, model_id=model_id, model_status="initialized"
+            )
+        return HealthCheckResponse(
+            status=HealthCheckStatus.OK,
+            model_id=configured_model_id,
+            model_status="unloaded",
+        )

-    Returns 503 if handler is not initialized, 200 otherwise.
-    """
     handler = getattr(raw_request.app.state, "handler", None)
-
     if handler is None:
         # Handler not initialized - return 503 with degraded status
         return JSONResponse(
             status_code=HTTPStatus.SERVICE_UNAVAILABLE,
             content={"status": "unhealthy", "model_id": None, "model_status": "uninitialized"},
         )

-    # Handler initialized - extract model_id
-    model_id = getattr(handler, "model_path", "unknown")
-
+    model_id = getattr(handler, "model_path", configured_model_id or "unknown")
     return HealthCheckResponse(
         status=HealthCheckStatus.OK, model_id=model_id, model_status="initialized"
     )
@@ -101,8 +187,14 @@ async def models(raw_request: Request) -> ModelsResponse | JSONResponse:
             status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
         )

+    cached_metadata = _get_cached_model_metadata(raw_request)
+    if cached_metadata is not None:
+        return ModelsResponse(object="list", data=[Model(**cached_metadata)])
+
     # Fallback to handler (Phase 0 compatibility)
-    handler = getattr(raw_request.app.state, "handler", None)
+    handler, error = await _get_handler_or_error(raw_request, "models")
+    if error is not None:
+        return error
     if handler is None:
         return JSONResponse(
             content=create_error_response(
@@ -138,7 +230,9 @@ async def queue_stats(raw_request: Request) -> dict[str, Any] | JSONResponse:
     Note: queue_stats shape is handler-dependent (Flux vs LM/VLM/Whisper)
     so callers know keys may vary.
     """
-    handler = getattr(raw_request.app.state, "handler", None)
+    handler, error = await _get_handler_or_error(raw_request, "queue_stats")
+    if error is not None:
+        return error
     if handler is None:
         return JSONResponse(
             content=create_error_response(
@@ -173,7 +267,9 @@ async def chat_completions(
     request: ChatCompletionRequest, raw_request: Request
 ) -> ChatCompletionResponse | StreamingResponse | JSONResponse:
     """Handle chat completion requests."""
-    handler = getattr(raw_request.app.state, "handler", None)
+    handler, error = await _get_handler_or_error(raw_request, "chat_completions")
+    if error is not None:
+        return error
     if handler is None:
         return JSONResponse(
             content=create_error_response(
@@ -213,7 +309,9 @@ async def embeddings(
     request: EmbeddingRequest, raw_request: Request
 ) -> EmbeddingResponse | JSONResponse:
     """Handle embedding requests."""
-    handler = getattr(raw_request.app.state, "handler", None)
+    handler, error = await _get_handler_or_error(raw_request, "embeddings")
+    if error is not None:
+        return error
     if handler is None:
         return JSONResponse(
             content=create_error_response(
@@ -224,6 +322,16 @@ async def embeddings(
             status_code=HTTPStatus.SERVICE_UNAVAILABLE,
         )

+    if not isinstance(handler, MLXEmbeddingsHandler):
+        return JSONResponse(
+            content=create_error_response(
+                "Embedding requests require an embeddings model. Use --model-type embeddings.",
+                "unsupported_request",
+                HTTPStatus.BAD_REQUEST,
+            ),
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
     try:
         embeddings = await handler.generate_embeddings_response(request)
         return create_response_embeddings(embeddings, request.model, request.encoding_format)
@@ -241,7 +349,9 @@ async def image_generations(
     request: ImageGenerationRequest, raw_request: Request
 ) -> ImageGenerationResponse | JSONResponse:
     """Handle image generation requests."""
-    handler = getattr(raw_request.app.state, "handler", None)
+    handler, error = await _get_handler_or_error(raw_request, "image_generation")
+    if error is not None:
+        return error
     if handler is None:
         return JSONResponse(
             content=create_error_response(
@@ -281,7 +391,9 @@ async def create_image_edit(
     request: Annotated[ImageEditRequest, Form()], raw_request: Request
 ) -> ImageEditResponse | JSONResponse:
     """Handle image editing requests with dynamic provider routing."""
-    handler = getattr(raw_request.app.state, "handler", None)
+    handler, error = await _get_handler_or_error(raw_request, "image_edit")
+    if error is not None:
+        return error
     if handler is None:
         return JSONResponse(
             content=create_error_response(
@@ -320,20 +432,32 @@ async def create_audio_transcriptions(
     request: Annotated[TranscriptionRequest, Form()], raw_request: Request
 ) -> StreamingResponse | TranscriptionResponse | JSONResponse | str:
     """Handle audio transcription requests."""
-    try:
-        handler = getattr(raw_request.app.state, "handler", None)
-        if handler is None:
-            return JSONResponse(
-                content=create_error_response(
-                    "Model handler not initialized",
-                    "service_unavailable",
-                    HTTPStatus.SERVICE_UNAVAILABLE,
-                ),
-                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
-            )
+    handler, error = await _get_handler_or_error(raw_request, "audio_transcriptions")
+    if error is not None:
+        return error
+    if handler is None:
+        return JSONResponse(
+            content=create_error_response(
+                "Model handler not initialized",
+                "service_unavailable",
+                HTTPStatus.SERVICE_UNAVAILABLE,
+            ),
+            status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+        )
+
+    if not isinstance(handler, MLXWhisperHandler):
+        return JSONResponse(
+            content=create_error_response(
+                "Audio transcription requests require a whisper model. Use --model-type whisper.",
+                "unsupported_request",
+                HTTPStatus.BAD_REQUEST,
+            ),
+            status_code=HTTPStatus.BAD_REQUEST,
+        )

+    try:
         if request.stream:
-            # procoess the request before sending to the handler
+            # process the request before sending to the handler
             request_data = await handler.prepare_transcription_request(request)
             return StreamingResponse(
                 handler.generate_transcription_stream_from_data(request_data),
@@ -401,6 +525,7 @@ def create_response_chunk(
     chat_id: str | None = None,
     created_time: int | None = None,
     request_id: str | None = None,
+    tool_call_id: str | None = None,
 ) -> ChatCompletionChunk:
     """Create a formatted response chunk for streaming."""
     chat_id = chat_id or get_id()
@@ -474,8 +599,12 @@ def create_response_chunk(
     if function_call:
         # Validate index exists before accessing
         tool_index = chunk.get("index", 0)
+        tool_identifier = tool_call_id or get_tool_call_id()
         tool_chunk = ChoiceDeltaToolCall(
-            index=tool_index, type="function", id=get_tool_call_id(), function=function_call
+            index=tool_index,
+            type="function",
+            id=tool_identifier,
+            function=function_call,
         )

         delta = Delta(content=None, role="assistant", tool_calls=[tool_chunk])  # type: ignore[call-arg]
@@ -509,7 +638,9 @@ async def handle_stream_response(
     chat_index = get_id()
     created_time = int(time.time())
     finish_reason = "stop"
-    tool_call_index = -1
+    next_tool_call_index = 0
+    current_implicit_tool_index: int | None = None
+    tool_call_ids: dict[int, str] = {}
     usage_info = None

     try:
@@ -546,19 +677,48 @@

             # Handle tool call chunks
             payload = dict(chunk)  # Create a copy to avoid mutating the original
-            if payload.get("name"):
+            current_tool_id = None
+
+            has_name = bool(payload.get("name"))
+            has_arguments = "arguments" in payload
+            payload_index = payload.get("index")
+
+            if has_name:
                 finish_reason = "tool_calls"
-                tool_call_index += 1
-                payload["index"] = tool_call_index
-            elif payload.get("arguments") and "index" not in payload:
-                payload["index"] = tool_call_index
+                if payload_index is None:
+                    if current_implicit_tool_index is not None:
+                        payload_index = current_implicit_tool_index
+                    else:
+                        payload_index = next_tool_call_index
+                        next_tool_call_index += 1
+                payload["index"] = payload_index
+                current_implicit_tool_index = payload_index
+                # Keep the implicit index available for additional argument chunks
+            elif has_arguments:
+                if payload_index is None:
+                    if current_implicit_tool_index is not None:
+                        payload_index = current_implicit_tool_index
+                    else:
+                        payload_index = next_tool_call_index
+                        next_tool_call_index += 1
+                payload["index"] = payload_index
+                current_implicit_tool_index = payload_index
+            elif payload_index is not None:
+                current_implicit_tool_index = payload_index
+
+            payload_index = payload.get("index")
+            if payload_index is not None:
+                if payload_index not in tool_call_ids:
+                    tool_call_ids[payload_index] = get_tool_call_id()
+                current_tool_id = tool_call_ids[payload_index]

             response_chunk = create_response_chunk(
                 payload,
                 model,
                 chat_id=chat_index,
                 created_time=created_time,
                 request_id=request_id,
+                tool_call_id=current_tool_id,
             )
             yield _yield_sse_chunk(response_chunk)
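The endpoints above assume only two members on `app.state.handler_manager`: an awaitable `ensure_loaded(reason)` that returns the handler (loading it on first use) and a `current_handler` attribute that is `None` while the model is unloaded. The manager itself lives in another file of this commit and is not shown here; the following is a minimal sketch of that interface under those assumptions, not the commit's actual implementation:

```python
import asyncio
import time
from typing import Any, Awaitable, Callable


class HandlerManagerSketch:
    """Hypothetical stand-in for the JIT/auto-unload manager the endpoints rely on."""

    def __init__(
        self,
        load_fn: Callable[[], Awaitable[Any]],  # async factory that builds the handler
        auto_unload_minutes: float | None = None,
    ) -> None:
        self._load_fn = load_fn
        self._idle_seconds = auto_unload_minutes * 60 if auto_unload_minutes else None
        self._lock = asyncio.Lock()
        self._last_used = time.monotonic()
        self.current_handler: Any | None = None  # None while unloaded (JIT idle state)

    async def ensure_loaded(self, reason: str) -> Any:
        """Load the handler on first use and refresh the idle timestamp."""
        async with self._lock:
            if self.current_handler is None:
                self.current_handler = await self._load_fn()
            self._last_used = time.monotonic()
            return self.current_handler

    async def unload_if_idle(self) -> None:
        """Body of a periodic task: drop the handler once the idle window elapses."""
        async with self._lock:
            if (
                self.current_handler is not None
                and self._idle_seconds is not None
                and time.monotonic() - self._last_used >= self._idle_seconds
            ):
                self.current_handler = None  # release references so VRAM can be reclaimed
```

With `--auto-unload-minutes` set, a background task would invoke something like `unload_if_idle()` on a timer; the `/health` endpoint then reports `model_status="unloaded"` until the next request triggers `ensure_loaded()` again.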