@@ -6,16 +6,18 @@
 import json
 import random
 import time
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, TypeAlias
 
 from fastapi import APIRouter, Form, HTTPException, Request
 from fastapi.responses import JSONResponse, StreamingResponse
 from loguru import logger
 import numpy as np
 
 from ..handler import MFLUX_AVAILABLE, MLXFluxHandler
+from ..handler.mlx_embeddings import MLXEmbeddingsHandler
 from ..handler.mlx_lm import MLXLMHandler
 from ..handler.mlx_vlm import MLXVLMHandler
+from ..handler.mlx_whisper import MLXWhisperHandler
 from ..schemas.openai import (
     ChatCompletionChunk,
     ChatCompletionMessageToolCall,
@@ -48,30 +50,114 @@
 router = APIRouter()
 
 
+MLXHandlerType: TypeAlias = (
+    MLXVLMHandler | MLXLMHandler | MLXFluxHandler | MLXEmbeddingsHandler | MLXWhisperHandler
+)
+
+
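+# Resolve the active handler in two phases: first ask the JIT handler manager to
+# load (or reuse) a handler on demand, then fall back to the legacy handler that
+# startup code attaches directly to app.state.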
+async def _get_handler_or_error(
+    raw_request: Request, reason: str
+) -> tuple[MLXHandlerType | None, JSONResponse | None]:
+    """Return a loaded handler or an error response if unavailable."""
+
+    handler_manager = getattr(raw_request.app.state, "handler_manager", None)
+    if handler_manager is not None:
+        try:
+            handler = await handler_manager.ensure_loaded(reason)
+        except HTTPException:
+            raise
+        except Exception as e:  # pragma: no cover - defensive logging
+            logger.exception(
+                f"Unable to load handler via JIT for {reason}. {type(e).__name__}: {e}"
+            )
+            return (
+                None,
+                JSONResponse(
+                    content=create_error_response(
+                        "Failed to load model handler",
+                        "server_error",
+                        HTTPStatus.INTERNAL_SERVER_ERROR,
+                    ),
+                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+                ),
+            )
+
+        if handler is not None:
+            return handler, None
+
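+    # Legacy path (Phase 0 compatibility): a handler attached directly at startup.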
+    handler = getattr(raw_request.app.state, "handler", None)
+    if handler is None:
+        return (
+            None,
+            JSONResponse(
+                content=create_error_response(
+                    "Model handler not initialized",
+                    "service_unavailable",
+                    HTTPStatus.SERVICE_UNAVAILABLE,
+                ),
+                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+            ),
+        )
+    return handler, None
+
+
+def _get_cached_model_metadata(raw_request: Request) -> dict[str, Any] | None:
+    """Fetch cached model metadata from application state, if available."""
+
+    metadata_cache = getattr(raw_request.app.state, "model_metadata", None)
+    if isinstance(metadata_cache, list) and metadata_cache:
+        entry = metadata_cache[0]
+        if isinstance(entry, dict):
+            return entry
+    return None
+
+
+def _get_configured_model_id(raw_request: Request) -> str | None:
+    """Return the configured model identifier from config or cache."""
+
+    config = getattr(raw_request.app.state, "server_config", None)
+    if config is not None:
+        return getattr(config, "model_identifier", getattr(config, "model_path", None))
+
+    cached = _get_cached_model_metadata(raw_request)
+    if cached is not None:
+        return cached.get("id")
+    return None
+
+
 # =============================================================================
 # Critical/Monitoring Endpoints - Defined first to ensure priority matching
 # =============================================================================
 
 
 @router.get("/health", response_model=None)
 async def health(raw_request: Request) -> HealthCheckResponse | JSONResponse:
-    """
-    Health check endpoint - verifies handler initialization status.
+    """Health check endpoint aware of JIT/auto-unload state."""
+    handler_manager = getattr(raw_request.app.state, "handler_manager", None)
+    configured_model_id = _get_configured_model_id(raw_request)
+
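+    # With a JIT manager present, an unloaded model is still healthy: report OK
+    # with model_status "unloaded" rather than 503, since it can load on demand.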
+    if handler_manager is not None:
+        handler = getattr(handler_manager, "current_handler", None)
+        if handler is not None:
+            model_id = getattr(handler, "model_path", configured_model_id or "unknown")
+            return HealthCheckResponse(
+                status=HealthCheckStatus.OK, model_id=model_id, model_status="initialized"
+            )
+        return HealthCheckResponse(
+            status=HealthCheckStatus.OK,
+            model_id=configured_model_id,
+            model_status="unloaded",
+        )
 
-    Returns 503 if handler is not initialized, 200 otherwise.
-    """
     handler = getattr(raw_request.app.state, "handler", None)
-
     if handler is None:
         # Handler not initialized - return 503 with degraded status
         return JSONResponse(
             status_code=HTTPStatus.SERVICE_UNAVAILABLE,
             content={"status": "unhealthy", "model_id": None, "model_status": "uninitialized"},
         )
 
-    # Handler initialized - extract model_id
-    model_id = getattr(handler, "model_path", "unknown")
-
+    model_id = getattr(handler, "model_path", configured_model_id or "unknown")
     return HealthCheckResponse(
         status=HealthCheckStatus.OK, model_id=model_id, model_status="initialized"
     )
@@ -101,8 +187,14 @@ async def models(raw_request: Request) -> ModelsResponse | JSONResponse:
             status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
         )
 
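+    # Serve /v1/models from metadata cached on app state so the endpoint stays
+    # accurate even while the model itself is unloaded.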
+    cached_metadata = _get_cached_model_metadata(raw_request)
+    if cached_metadata is not None:
+        return ModelsResponse(object="list", data=[Model(**cached_metadata)])
+
     # Fallback to handler (Phase 0 compatibility)
-    handler = getattr(raw_request.app.state, "handler", None)
+    handler, error = await _get_handler_or_error(raw_request, "models")
+    if error is not None:
+        return error
     if handler is None:
         return JSONResponse(
             content=create_error_response(
@@ -138,7 +230,9 @@ async def queue_stats(raw_request: Request) -> dict[str, Any] | JSONResponse:
     Note: queue_stats shape is handler-dependent (Flux vs LM/VLM/Whisper)
     so callers know keys may vary.
     """
-    handler = getattr(raw_request.app.state, "handler", None)
+    handler, error = await _get_handler_or_error(raw_request, "queue_stats")
+    if error is not None:
+        return error
     if handler is None:
         return JSONResponse(
             content=create_error_response(
@@ -173,7 +267,9 @@ async def chat_completions(
     request: ChatCompletionRequest, raw_request: Request
 ) -> ChatCompletionResponse | StreamingResponse | JSONResponse:
     """Handle chat completion requests."""
-    handler = getattr(raw_request.app.state, "handler", None)
+    handler, error = await _get_handler_or_error(raw_request, "chat_completions")
+    if error is not None:
+        return error
     if handler is None:
         return JSONResponse(
             content=create_error_response(
@@ -213,7 +309,9 @@ async def embeddings(
     request: EmbeddingRequest, raw_request: Request
 ) -> EmbeddingResponse | JSONResponse:
     """Handle embedding requests."""
-    handler = getattr(raw_request.app.state, "handler", None)
+    handler, error = await _get_handler_or_error(raw_request, "embeddings")
+    if error is not None:
+        return error
     if handler is None:
         return JSONResponse(
             content=create_error_response(
@@ -224,6 +322,16 @@ async def embeddings(
             status_code=HTTPStatus.SERVICE_UNAVAILABLE,
         )
 
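+    # The JIT path can hand back any configured handler type, so reject
+    # non-embeddings handlers with a 400 before touching embedding APIs.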
+    if not isinstance(handler, MLXEmbeddingsHandler):
+        return JSONResponse(
+            content=create_error_response(
+                "Embedding requests require an embeddings model. Use --model-type embeddings.",
+                "unsupported_request",
+                HTTPStatus.BAD_REQUEST,
+            ),
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
     try:
         embeddings = await handler.generate_embeddings_response(request)
         return create_response_embeddings(embeddings, request.model, request.encoding_format)
@@ -241,7 +349,9 @@ async def image_generations(
     request: ImageGenerationRequest, raw_request: Request
 ) -> ImageGenerationResponse | JSONResponse:
     """Handle image generation requests."""
-    handler = getattr(raw_request.app.state, "handler", None)
+    handler, error = await _get_handler_or_error(raw_request, "image_generation")
+    if error is not None:
+        return error
     if handler is None:
         return JSONResponse(
             content=create_error_response(
@@ -281,7 +391,9 @@ async def create_image_edit(
     request: Annotated[ImageEditRequest, Form()], raw_request: Request
 ) -> ImageEditResponse | JSONResponse:
     """Handle image editing requests with dynamic provider routing."""
-    handler = getattr(raw_request.app.state, "handler", None)
+    handler, error = await _get_handler_or_error(raw_request, "image_edit")
+    if error is not None:
+        return error
     if handler is None:
         return JSONResponse(
             content=create_error_response(
@@ -320,20 +432,32 @@ async def create_audio_transcriptions(
     request: Annotated[TranscriptionRequest, Form()], raw_request: Request
 ) -> StreamingResponse | TranscriptionResponse | JSONResponse | str:
     """Handle audio transcription requests."""
-    try:
-        handler = getattr(raw_request.app.state, "handler", None)
-        if handler is None:
-            return JSONResponse(
-                content=create_error_response(
-                    "Model handler not initialized",
-                    "service_unavailable",
-                    HTTPStatus.SERVICE_UNAVAILABLE,
-                ),
-                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
-            )
+    handler, error = await _get_handler_or_error(raw_request, "audio_transcriptions")
+    if error is not None:
+        return error
+    if handler is None:
+        return JSONResponse(
+            content=create_error_response(
+                "Model handler not initialized",
+                "service_unavailable",
+                HTTPStatus.SERVICE_UNAVAILABLE,
+            ),
+            status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+        )
+
+    if not isinstance(handler, MLXWhisperHandler):
+        return JSONResponse(
+            content=create_error_response(
+                "Audio transcription requests require a whisper model. Use --model-type whisper.",
+                "unsupported_request",
+                HTTPStatus.BAD_REQUEST,
+            ),
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
 
+    try:
         if request.stream:
-            # procoess the request before sending to the handler
+            # process the request before sending to the handler
             request_data = await handler.prepare_transcription_request(request)
             return StreamingResponse(
                 handler.generate_transcription_stream_from_data(request_data),
@@ -401,6 +525,7 @@ def create_response_chunk(
     chat_id: str | None = None,
     created_time: int | None = None,
     request_id: str | None = None,
+    tool_call_id: str | None = None,
 ) -> ChatCompletionChunk:
     """Create a formatted response chunk for streaming."""
     chat_id = chat_id or get_id()
@@ -474,8 +599,12 @@ def create_response_chunk(
     if function_call:
         # Validate index exists before accessing
         tool_index = chunk.get("index", 0)
+        tool_identifier = tool_call_id or get_tool_call_id()
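+        # Prefer the caller-supplied id so every chunk of one streamed tool call
+        # carries the same id; mint a fresh one only as a fallback.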
         tool_chunk = ChoiceDeltaToolCall(
-            index=tool_index, type="function", id=get_tool_call_id(), function=function_call
+            index=tool_index,
+            type="function",
+            id=tool_identifier,
+            function=function_call,
         )
 
         delta = Delta(content=None, role="assistant", tool_calls=[tool_chunk])  # type: ignore[call-arg]
@@ -509,7 +638,9 @@ async def handle_stream_response(
     chat_index = get_id()
     created_time = int(time.time())
     finish_reason = "stop"
-    tool_call_index = -1
+    next_tool_call_index = 0
+    current_implicit_tool_index: int | None = None
+    tool_call_ids: dict[int, str] = {}
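+    # next_tool_call_index numbers tool calls the model emits without an index;
+    # current_implicit_tool_index tracks which call argument-only chunks belong to;
+    # tool_call_ids keeps one stable id per index, reused across that call's chunks.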
     usage_info = None
 
     try:
@@ -546,19 +677,48 @@ async def handle_stream_response(
 
             # Handle tool call chunks
             payload = dict(chunk)  # Create a copy to avoid mutating the original
-            if payload.get("name"):
+            current_tool_id = None
+
+            has_name = bool(payload.get("name"))
+            has_arguments = "arguments" in payload
+            payload_index = payload.get("index")
+
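+            # Three cases: a named chunk opens or extends a tool call, an
+            # argument-only chunk attaches to the current implicit call, and an
+            # explicitly indexed chunk simply updates the implicit index.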
+            if has_name:
                 finish_reason = "tool_calls"
-                tool_call_index += 1
-                payload["index"] = tool_call_index
-            elif payload.get("arguments") and "index" not in payload:
-                payload["index"] = tool_call_index
+                if payload_index is None:
+                    if current_implicit_tool_index is not None:
+                        payload_index = current_implicit_tool_index
+                    else:
+                        payload_index = next_tool_call_index
+                        next_tool_call_index += 1
+                payload["index"] = payload_index
+                current_implicit_tool_index = payload_index
+                # Keep the implicit index available for additional argument chunks
+            elif has_arguments:
+                if payload_index is None:
+                    if current_implicit_tool_index is not None:
+                        payload_index = current_implicit_tool_index
+                    else:
+                        payload_index = next_tool_call_index
+                        next_tool_call_index += 1
+                payload["index"] = payload_index
+                current_implicit_tool_index = payload_index
+            elif payload_index is not None:
+                current_implicit_tool_index = payload_index
+
+            payload_index = payload.get("index")
+            if payload_index is not None:
+                if payload_index not in tool_call_ids:
+                    tool_call_ids[payload_index] = get_tool_call_id()
+                current_tool_id = tool_call_ids[payload_index]
 
             response_chunk = create_response_chunk(
                 payload,
                 model,
                 chat_id=chat_index,
                 created_time=created_time,
                 request_id=request_id,
+                tool_call_id=current_tool_id,
             )
             yield _yield_sse_chunk(response_chunk)
 