Skip to content

Commit 47aaa0e

Browse files
committed
feat: added response_format support for llm_inference_api + reduced base_inference_api spam for requests
1 parent e96e332 commit 47aaa0e

File tree

5 files changed

+215
-7
lines changed

5 files changed

+215
-7
lines changed

extensions/business/edge_inference_api/base_inference_api.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@
8585
"REQUEST_TIMEOUT": 600, # 10 minutes
8686
"SAVE_PERIOD": 300, # 5 minutes
8787

88+
"LOG_REQUESTS_STATUS_EVERY_SECONDS": 5, # log pending request status every 5 seconds
89+
8890
"REQUEST_TTL_SECONDS": 60 * 60 * 2, # keep historical results for 2 hours
8991
"RATE_LIMIT_PER_MINUTE": 5,
9092
"AUTH_TOKEN_ENV": "INFERENCE_API_TOKEN",
@@ -129,6 +131,7 @@ def on_init(self):
129131
self.P(err_msg)
130132
raise ValueError(err_msg)
131133
# endif AI_ENGINE not specified
134+
self._request_last_log_time: Dict[str, float] = {}
132135
self._requests: Dict[str, Dict[str, Any]] = {}
133136
self._api_errors: Dict[str, Dict[str, Any]] = {}
134137
# TODO: add inference metrics tracking (latency, tokens, etc)
@@ -569,7 +572,11 @@ def solve_postponed_request(self, request_id: str):
569572
Request result when completed or failed, or a PostponedRequest for pending work.
570573
"""
571574
if request_id in self._requests:
572-
self.Pd(f"Checking status of request ID {request_id}...")
575+
last_logged_status = self._request_last_log_time.get(request_id, 0)
576+
if (self.time() - last_logged_status) > self.cfg_log_requests_status_every_seconds:
577+
self.Pd(f"Checking status of request ID {request_id}...")
578+
self._request_last_log_time[request_id] = self.time()
579+
# endif logging status
573580
request_data = self._requests[request_id]
574581

575582
self.maybe_mark_request_timeout(request_id=request_id, request_data=request_data)

extensions/business/edge_inference_api/llm_inference_api.py

Lines changed: 138 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
from extensions.business.edge_inference_api.base_inference_api import BaseInferenceApiPlugin as BasePlugin
9191
from extensions.serving.mixins_llm.llm_utils import LlmCT
9292

93-
from typing import Any, Dict, List, Optional
93+
from typing import Any, Dict, List, Optional, Tuple
9494

9595

9696
_CONFIG = {
@@ -142,11 +142,117 @@ def check_messages(self, messages: list[dict]):
142142
return f"Message {idx} content must be a non-empty string."
143143
return None
144144

145+
def check_and_normalize_response_format(self, response_format) -> Tuple[Optional[Dict[str, Any]], str]:
  """
  Validate and normalize the value received for response_format in an inference request.

  Supported inputs:
    - None -> returns None (no constraint requested)
    - dict -> validates and normalizes
    - str  -> if it looks like JSON, parses to dict then validates

  Accepted forms:
    A) llama-cpp-python documented forms:
      - {"type": "json_object"}
      - {"type": "json_object", "schema": {<json-schema>}}

    B) llama.cpp server alternative form:
      - {"type": "json_schema", "json_schema": {"schema": {<json-schema>}, ...}}

  Parameters
  ----------
  response_format : dict or str, optional
    Controls structured output constraints for the model response.

  Returns
  -------
  (result, err_msg), where
    result: dict or None
      The normalized value of response_format; None when absent or invalid.
    err_msg: str
      Error message when validation fails, otherwise the empty string.
  """
  if not response_format:
    # None / empty dict / empty string all mean "no response_format requested".
    return None, ""
  # endif response_format not provided

  # Accept JSON string input (common in REST APIs).
  if isinstance(response_format, str):
    stripped = response_format.strip()
    if not stripped:
      # Treat empty string as "no response_format"
      return None, ""
    try:
      response_format = self.json.loads(stripped)
    except Exception as e:
      return None, f"response_format is a string but not valid JSON: {e}"
  # endif response_format received as string

  if not isinstance(response_format, dict):
    return None, f"response_format expected as a dict, but received {type(response_format)} instead."
  # endif invalid type

  if "type" not in response_format:
    return None, "response_format missing required key 'type'."
  fmt_type = response_format['type']
  if not isinstance(fmt_type, str):
    return None, f"Key 'type' from response_format should be a string, but received {type(fmt_type)} instead."
  # endif type checking

  def _check_schema(_schema: Any, where: str):
    # Returns (schema_or_None, err_msg). A missing schema is not an error at this level;
    # the caller decides whether a schema is mandatory for the given format type.
    if _schema is None:
      return None, ""
    if not isinstance(_schema, dict):
      # NOTE: original message lacked a space before the type ("received{type(...)}") - fixed.
      return None, f"{where} from response_format must be an object/dict (JSON Schema) if provided, but received {type(_schema)} instead."
    try:
      # Check if schema is JSON-serializable
      self.json.dumps(_schema)
    except Exception as e:
      return None, f"{where} from response_format must be JSON-serializable if provided: {e}"
    return _schema, ""
  # enddef _check_schema

  fmt_type = fmt_type.strip()
  if fmt_type == 'json_object':
    schema, err_msg = _check_schema(_schema=response_format.get('schema'), where="'schema'")
    if err_msg:
      # Do not return a partially-built result together with an error.
      return None, err_msg
    result = {"type": "json_object"}
    if schema is not None:
      result["schema"] = schema
    return result, ""
  # endif json_object

  if fmt_type == 'json_schema':
    # Check for both 'json_schema' and 'schema'
    wrapper = response_format.get('json_schema')
    if wrapper is not None:
      if not isinstance(wrapper, dict):
        return None, "'json_schema' from response_format must be an object/dict."
      schema, err_msg = _check_schema(_schema=wrapper.get('schema'), where="'json_schema.schema'")
    else:
      schema, err_msg = _check_schema(_schema=response_format.get('schema'), where="'schema'")
    # endif json_schema specified
    if err_msg:
      # Preserve the specific validation error instead of overwriting it with the
      # generic "missing schema" message (original behavior lost this detail).
      return None, err_msg
    if schema is None:
      return None, "json_schema response_format requires a schema (missing 'schema'/'json_schema.schema')"
    # Here, "json_object" is put as type, since it is the more reliable one according to llama.cpp documentation.
    return {"type": "json_object", "schema": schema}, ""
  # endif json_schema

  return None, f"Unsupported response_format type '{fmt_type}'. Supported: 'json_object', 'json_schema'."
249+
145250
def check_generation_params(
146251
self,
147252
temperature: float,
148253
max_tokens: int,
149254
top_p: float = 1.0,
255+
response_format: Any = None,
150256
**kwargs
151257
):
152258
"""
@@ -160,6 +266,8 @@ def check_generation_params(
160266
Maximum number of tokens to generate.
161267
top_p : float, optional
162268
Nucleus sampling cutoff between 0 and 1.
269+
response_format : dict or None, optional
270+
Controls structured output constraints for the model response.
163271
**kwargs
164272
Additional unused parameters.
165273
@@ -180,7 +288,8 @@ def check_generation_params(
180288
)
181289
if not 0 < top_p <= 1:
182290
return "top_p must be between 0 and 1."
183-
return None
291+
_, err_msg = self.check_and_normalize_response_format(response_format=response_format)
292+
return err_msg
184293

185294
def normalize_messages(self, messages: List[Dict[str, Any]]):
186295
"""
@@ -215,6 +324,7 @@ def predict(
215324
max_tokens: int = 512,
216325
top_p: float = 1.0,
217326
repeat_penalty: Optional[float] = 1.0,
327+
response_format: Optional[Dict[str, Any]] = None,
218328
metadata: Optional[Dict[str, Any]] = None,
219329
authorization: Optional[str] = None,
220330
**kwargs
@@ -234,6 +344,8 @@ def predict(
234344
Nucleus sampling probability threshold.
235345
repeat_penalty : float or None, optional
236346
Penalty for repeated tokens if supported by the backend.
347+
response_format : dict or None, optional
348+
Controls structured output constraints for the model response.
237349
metadata : dict or None, optional
238350
Additional metadata to store with the request.
239351
authorization : str or None, optional
@@ -252,6 +364,7 @@ def predict(
252364
max_tokens=max_tokens,
253365
top_p=top_p,
254366
repeat_penalty=repeat_penalty,
367+
response_format=response_format,
255368
metadata=metadata,
256369
authorization=authorization,
257370
**kwargs
@@ -265,6 +378,7 @@ def predict_async(
265378
max_tokens: int = 512,
266379
top_p: float = 1.0,
267380
repeat_penalty: float = 1.0,
381+
response_format: Optional[Dict[str, Any]] = None,
268382
metadata: Optional[Dict[str, Any]] = None,
269383
authorization: Optional[str] = None,
270384
**kwargs
@@ -284,6 +398,8 @@ def predict_async(
284398
Nucleus sampling probability threshold.
285399
repeat_penalty : float, optional
286400
Penalty for repeated tokens if supported by the backend.
401+
response_format : dict or None, optional
402+
Controls structured output constraints for the model response.
287403
metadata : dict or None, optional
288404
Additional metadata to store with the request.
289405
authorization : str or None, optional
@@ -302,6 +418,7 @@ def predict_async(
302418
max_tokens=max_tokens,
303419
top_p=top_p,
304420
repeat_penalty=repeat_penalty,
421+
response_format=response_format,
305422
metadata=metadata,
306423
authorization=authorization,
307424
**kwargs
@@ -315,6 +432,7 @@ def create_chat_completion(
315432
max_tokens: int = 512,
316433
top_p: float = 1.0,
317434
repeat_penalty: Optional[float] = 1.0,
435+
response_format: Optional[Dict[str, Any]] = None,
318436
metadata: Optional[Dict[str, Any]] = None,
319437
authorization: Optional[str] = None,
320438
**kwargs
@@ -334,6 +452,8 @@ def create_chat_completion(
334452
Nucleus sampling probability threshold.
335453
repeat_penalty : float or None, optional
336454
Penalty for repeated tokens if supported by the backend.
455+
response_format : dict or None, optional
456+
Controls structured output constraints for the model response.
337457
metadata : dict or None, optional
338458
Additional metadata to store with the request.
339459
authorization : str or None, optional
@@ -352,6 +472,7 @@ def create_chat_completion(
352472
max_tokens=max_tokens,
353473
top_p=top_p,
354474
repeat_penalty=repeat_penalty,
475+
response_format=response_format,
355476
metadata=metadata,
356477
authorization=authorization,
357478
**kwargs
@@ -365,6 +486,7 @@ def create_chat_completion_async(
365486
max_tokens: int = 512,
366487
top_p: float = 1.0,
367488
repeat_penalty: float = 1.0,
489+
response_format: Optional[Dict[str, Any]] = None,
368490
metadata: Optional[Dict[str, Any]] = None,
369491
authorization: Optional[str] = None,
370492
**kwargs
@@ -384,6 +506,8 @@ def create_chat_completion_async(
384506
Nucleus sampling probability threshold.
385507
repeat_penalty : float, optional
386508
Penalty for repeated tokens if supported by the backend.
509+
response_format : dict or None, optional
510+
Controls structured output constraints for the model response.
387511
metadata : dict or None, optional
388512
Additional metadata to store with the request.
389513
authorization : str or None, optional
@@ -402,6 +526,7 @@ def create_chat_completion_async(
402526
max_tokens=max_tokens,
403527
top_p=top_p,
404528
repeat_penalty=repeat_penalty,
529+
response_format=response_format,
405530
metadata=metadata,
406531
authorization=authorization,
407532
**kwargs
@@ -417,6 +542,7 @@ def check_predict_params(
417542
max_tokens: int,
418543
top_p: float = 1.0,
419544
repeat_penalty: float = 1.0,
545+
response_format: Optional[Dict[str, Any]] = None,
420546
**kwargs
421547
):
422548
"""
@@ -434,6 +560,8 @@ def check_predict_params(
434560
Nucleus sampling probability threshold.
435561
repeat_penalty : float, optional
436562
Penalty for repeated tokens if supported by the backend.
563+
response_format : dict or None, optional
564+
Controls structured output constraints for the model response.
437565
**kwargs
438566
Additional parameters not validated here.
439567
@@ -449,9 +577,10 @@ def check_predict_params(
449577
temperature=temperature,
450578
max_tokens=max_tokens,
451579
top_p=top_p,
580+
response_format=response_format,
452581
**kwargs
453582
)
454-
if err is not None:
583+
if err:
455584
return err
456585
return None
457586

@@ -462,6 +591,7 @@ def process_predict_params(
462591
max_tokens: int,
463592
top_p: float = 1.0,
464593
repeat_penalty: float = 1.0,
594+
response_format: Optional[Dict[str, Any]] = None,
465595
**kwargs
466596
):
467597
"""
@@ -479,6 +609,8 @@ def process_predict_params(
479609
Nucleus sampling probability threshold.
480610
repeat_penalty : float, optional
481611
Penalty for repeated tokens if supported by the backend.
612+
response_format : dict or None, optional
613+
Controls structured output constraints for the model response.
482614
**kwargs
483615
Additional parameters to include as-is.
484616
@@ -488,12 +620,15 @@ def process_predict_params(
488620
Processed parameters ready for dispatch.
489621
"""
490622
normalized_messages = self.normalize_messages(messages)
623+
# No need to capture err_msg here, already validated in check_predict_params
624+
response_format, _ = self.check_and_normalize_response_format(response_format=response_format)
491625
return {
492626
'messages': normalized_messages,
493627
'temperature': temperature,
494628
'max_tokens': max_tokens,
495629
'top_p': top_p,
496630
'repeat_penalty': repeat_penalty,
631+
'response_format': response_format,
497632
**kwargs
498633
}
499634

0 commit comments

Comments
 (0)