
Commit 251b034

Add support for json_schema parameter (#7)
1 parent: d8e1ab9

6 files changed (+127, −15 lines)
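The new json_schema value travels as a form field: the client sends a JSON string whose top-level "schema" key (SCHEMA_KEY in the diff below) wraps the actual JSON Schema, and the server parses it with json.loads before routing. A minimal sketch of building that field value, using the one-property schema that appears in the tests (the "answer" property is illustrative only):

# Sketch of the json_schema form-field value; the "answer" property is illustrative.
import json

json_schema_field = json.dumps({
    "schema": {  # the server unwraps this key (SCHEMA_KEY = "schema")
        "type": "object",
        "properties": {"answer": {"type": "string"}},
    }
})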

ai_server/server.py

Lines changed: 36 additions & 7 deletions
@@ -2,6 +2,7 @@
 
 import base64
 import glob
+import json
 import os
 import subprocess
 from typing import Optional
@@ -32,6 +33,7 @@
     if _llama_server_url and not _llama_server_url.startswith(('http://', 'https://'))
     else _llama_server_url
 )
+SCHEMA_KEY = "schema"
 
 
 def _build_messages(content: str, system_prompt: Optional[str] = None, image_files: Optional[list] = None) -> list:
@@ -52,17 +54,21 @@ def chat_with_llama_server_http(
     system_prompt: Optional[str] = None,
     timeout: int = 300,
     image_files: Optional[list] = None,
+    json_schema: Optional[dict] = None,
 ) -> str:
     """Handle chat using llama-server HTTP API."""
     if not LLAMA_SERVER_URL:
         raise Exception("LLAMA_SERVER_URL environment variable not set")
 
     try:
         messages = _build_messages(content, system_prompt, image_files=[])  # TODO: Pass image files
+        payload = {'model': model, 'messages': messages, 'stream': False, 'max_tokens': 512}
+        if json_schema:
+            payload['json_schema'] = json_schema[SCHEMA_KEY]
 
         response = requests.post(
             f'{LLAMA_SERVER_URL}/v1/chat/completions',
-            json={'model': model, 'messages': messages, 'stream': False, 'max_tokens': 512},
+            json=payload,
             headers={'Content-Type': 'application/json'},
             timeout=timeout,
         )
@@ -95,12 +101,21 @@ def is_llamacpp_available(model: str) -> bool:
 
 
 def chat_with_ollama(
-    model: str, content: str, system_prompt: Optional[str] = None, image_files: Optional[list] = None
+    model: str,
+    content: str,
+    system_prompt: Optional[str] = None,
+    image_files: Optional[list] = None,
+    json_schema: Optional[dict] = None,
 ) -> str:
     """Handle chat using ollama."""
     messages = _build_messages(content, system_prompt, image_files)
 
-    response = ollama.chat(model=model, messages=messages, stream=False)
+    response = ollama.chat(
+        model=model,
+        messages=messages,
+        stream=False,
+        format=json_schema[SCHEMA_KEY] if json_schema else None,
+    )
     return response.message.content
 
 
@@ -110,6 +125,7 @@ def chat_with_llamacpp(
     system_prompt: Optional[str] = None,
     timeout: int = 300,
     image_files: Optional[list] = None,
+    json_schema: Optional[dict] = None,
 ) -> str:
     """Handle chat using llama.cpp CLI."""
     model_path = resolve_model_path(model)
@@ -118,6 +134,9 @@
         raise ValueError(f"Model not found: {model}")
 
     cmd = [LLAMA_CPP_CLI, '-m', model_path, '--n-gpu-layers', '40', '-p', content, '-n', '512', '--single-turn']
+    if json_schema:
+        raw_schema = json_schema[SCHEMA_KEY] if SCHEMA_KEY in json_schema else json_schema
+        cmd += ["--json-schema", json.dumps(raw_schema)]
 
     # Add system prompt if provided
     if system_prompt:
@@ -152,20 +171,27 @@ def chat_with_model(
     llama_mode: str = "cli",
     system_prompt: Optional[str] = None,
     image_files: Optional[list] = None,
+    json_schema: Optional[dict] = None,
 ) -> str:
     """Route chat request based on llama_mode: server (external), cli, or ollama fallback; and with optional system prompt."""
     if is_llamacpp_available(model):
         if llama_mode == "server":
             if not LLAMA_SERVER_URL:
                 raise Exception("LLAMA_SERVER_URL environment variable not set for server mode")
-            return chat_with_llama_server_http(model, content, system_prompt=system_prompt, image_files=image_files)
+            return chat_with_llama_server_http(
+                model, content, system_prompt=system_prompt, image_files=image_files, json_schema=json_schema
+            )
         elif llama_mode == "cli":
-            return chat_with_llamacpp(model, content, system_prompt=system_prompt, image_files=image_files)
+            return chat_with_llamacpp(
+                model, content, system_prompt=system_prompt, image_files=image_files, json_schema=json_schema
+            )
         else:
             raise ValueError(f"Invalid llama_mode: '{llama_mode}'. Valid options are 'server' or 'cli'.")
     else:
         # Model not available in llama.cpp, use ollama
-        return chat_with_ollama(model, content, system_prompt=system_prompt, image_files=image_files)
+        return chat_with_ollama(
+            model, content, system_prompt=system_prompt, image_files=image_files, json_schema=json_schema
+        )
 
 
 def authenticate() -> str:
@@ -190,11 +216,14 @@ def chat():
     llama_mode = request.form.get('llama_mode', 'cli')
     system_prompt = request.form.get('system_prompt')
    image_files = list(request.files.values())
+    json_schema = request.form.get('json_schema')
+    if json_schema:
+        json_schema = json.loads(json_schema)
 
     if not content.strip():
         abort(400, description='Missing prompt content')
 
-    response_content = chat_with_model(model, content, llama_mode, system_prompt, image_files)
+    response_content = chat_with_model(model, content, llama_mode, system_prompt, image_files, json_schema)
     return jsonify(response_content)
 
 
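For context, a hedged client-side sketch of exercising the new parameter end to end. It assumes the Flask route shown above is mounted at /chat on localhost:5000, that the handler also reads 'model' and 'content' from the form, and that no auth header is needed; those names, the port, and the model name are assumptions for illustration, not taken from the diff.

# Hypothetical client call; route path, port, field names 'model'/'content',
# and the model name are assumptions for illustration only.
import json
import requests

schema = {"schema": {"type": "object", "properties": {"answer": {"type": "string"}}}}

resp = requests.post(
    "http://localhost:5000/chat",
    data={
        "model": "DeepSeek-V3",
        "content": "Give me an answer",
        "llama_mode": "cli",
        "json_schema": json.dumps(schema),  # parsed server-side with json.loads
    },
    timeout=300,
)
print(resp.json())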

test/test_cli_mode.py

Lines changed: 30 additions & 3 deletions
@@ -108,7 +108,9 @@ def test_cli_mode_uses_llamacpp_when_available(self):
 
         assert result == "CLI response from DeepSeek V3"
         self.mock_available.assert_called_once_with(TEST_LLAMACPP_MODEL)
-        self.mock_chat_llamacpp.assert_called_once_with(TEST_LLAMACPP_MODEL, 'Write a function', system_prompt=None, image_files=None)
+        self.mock_chat_llamacpp.assert_called_once_with(
+            TEST_LLAMACPP_MODEL, 'Write a function', system_prompt=None, image_files=None, json_schema=None
+        )
 
     def test_cli_mode_fallback_to_ollama_when_unavailable(self):
         """Test CLI mode falls back to ollama when model not available in llama.cpp."""
@@ -119,7 +121,9 @@ def test_cli_mode_fallback_to_ollama_when_unavailable(self):
 
         assert result == "Ollama response from DeepSeek Coder"
         self.mock_available.assert_called_once_with(TEST_OLLAMA_MODEL)
-        self.mock_chat_ollama.assert_called_once_with(TEST_OLLAMA_MODEL, 'Help with coding', system_prompt=None, image_files=None)
+        self.mock_chat_ollama.assert_called_once_with(
+            TEST_OLLAMA_MODEL, 'Help with coding', system_prompt=None, image_files=None, json_schema=None
+        )
 
     def test_default_mode_is_cli(self):
         """Test that default mode is CLI when no llama_mode specified."""
@@ -130,7 +134,9 @@ def test_default_mode_is_cli(self):
 
         assert result == "Default CLI mode response"
         self.mock_available.assert_called_once_with(TEST_LLAMACPP_MODEL)
-        self.mock_chat_llamacpp.assert_called_once_with(TEST_LLAMACPP_MODEL, 'Help me', system_prompt=None, image_files=None)
+        self.mock_chat_llamacpp.assert_called_once_with(
+            TEST_LLAMACPP_MODEL, 'Help me', system_prompt=None, image_files=None, json_schema=None
+        )
 
 
 class TestCLIModeIntegration:
@@ -167,3 +173,24 @@ def test_complete_cli_fallback_flow_to_ollama(self, mock_glob, mock_ollama):
         assert result == "Ollama CLI fallback integration test successful!"
         mock_glob.assert_called_once_with(f'/data1/GGUF/{TEST_OLLAMA_MODEL}/*.gguf')
         mock_ollama.assert_called_once()
+
+    def test_cli_mode_passes_json_schema_to_ollama(self, tmp_path):
+        """
+        When json_schema is supplied, chat_with_model should forward the parsed
+        schema (as a dict) to chat_with_ollama.
+        """
+        test_schema = {"schema": {"type": "object", "properties": {"answer": {"type": "string"}}}}
+
+        # Prepare mocks
+        with patch('ai_server.server.is_llamacpp_available', return_value=False), patch(
+            'ai_server.server.chat_with_ollama'
+        ) as mock_ollama:
+            mock_ollama.return_value = "schema-aware response"
+
+            result = chat_with_model(TEST_OLLAMA_MODEL, "Give me an answer", llama_mode='cli', json_schema=test_schema)
+
+            assert result == "schema-aware response"
+
+            mock_ollama.assert_called_once_with(
+                TEST_OLLAMA_MODEL, "Give me an answer", system_prompt=None, image_files=None, json_schema=test_schema
+            )

test/test_core.py

Lines changed: 21 additions & 0 deletions
@@ -83,6 +83,7 @@ def test_chat_with_ollama_success(self, mock_ollama):
             model=TEST_OLLAMA_MODEL,
             messages=[{'role': 'user', 'content': 'Help me write a Python function'}],
             stream=False,
+            format=None,
         )
 
     def test_chat_with_ollama_service_unavailable(self, mock_ollama):
@@ -98,3 +99,23 @@ def test_chat_with_ollama_model_not_found(self, mock_ollama):
 
         with pytest.raises(Exception, match="model 'nonexistent:latest' not found"):
             chat_with_ollama('nonexistent:latest', 'Hello')
+
+    def test_chat_with_ollama_with_json_schema(self, mock_ollama, tmp_path):
+        """Ollama chat should forward the JSON schema (format=…) when provided."""
+        # Fake schema file
+        test_schema = {"schema": {"type": "object", "properties": {"answer": {"type": "string"}}}}
+
+        # Mock ollama response
+        mock_response = MagicMock()
+        mock_response.message.content = "42"
+        mock_ollama.return_value = mock_response
+
+        result = chat_with_ollama(TEST_OLLAMA_MODEL, "What is the meaning of life?", json_schema=test_schema)
+
+        assert result == "42"
+        mock_ollama.assert_called_once_with(
+            model=TEST_OLLAMA_MODEL,
+            messages=[{"role": "user", "content": "What is the meaning of life?"}],
+            stream=False,
+            format={"type": "object", "properties": {"answer": {"type": "string"}}},
+        )
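The test above pins down how chat_with_ollama drives structured output: the unwrapped schema is handed to ollama.chat as its format argument. A standalone sketch of that call against a real ollama instance, assuming the Python client is installed and a model named 'llama3' has been pulled (both assumptions, not part of the commit):

# Sketch only: requires a running ollama daemon and a pulled model ('llama3' is assumed).
import ollama

schema = {"type": "object", "properties": {"answer": {"type": "string"}}}

response = ollama.chat(
    model="llama3",
    messages=[{"role": "user", "content": "What is the meaning of life?"}],
    stream=False,
    format=schema,  # same structured-output mechanism used by chat_with_ollama
)
print(response.message.content)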

test/test_server_mode.py

Lines changed: 35 additions & 2 deletions
@@ -111,7 +111,9 @@ def test_server_mode_uses_llamacpp_when_available(self):
 
         assert result == "Server response from DeepSeek V3"
         self.mock_available.assert_called_once_with(TEST_LLAMACPP_MODEL)
-        self.mock_chat_server.assert_called_once_with(TEST_LLAMACPP_MODEL, 'Explain code', system_prompt=None, image_files=None)
+        self.mock_chat_server.assert_called_once_with(
+            TEST_LLAMACPP_MODEL, 'Explain code', system_prompt=None, image_files=None, json_schema=None
+        )
 
     def test_server_mode_fallback_to_ollama_when_unavailable(self):
         """Test server mode falls back to ollama when model not available in llama.cpp."""
@@ -122,7 +124,9 @@ def test_server_mode_fallback_to_ollama_when_unavailable(self):
 
         assert result == "Ollama fallback response"
         self.mock_available.assert_called_once_with(TEST_OLLAMA_MODEL)
-        self.mock_chat_ollama.assert_called_once_with(TEST_OLLAMA_MODEL, 'Debug code', system_prompt=None, image_files=None)
+        self.mock_chat_ollama.assert_called_once_with(
+            TEST_OLLAMA_MODEL, 'Debug code', system_prompt=None, image_files=None, json_schema=None
+        )
 
     def test_server_mode_requires_server_url(self):
         """Test server mode requires LLAMA_SERVER_URL to be set."""
@@ -178,3 +182,32 @@ def test_complete_server_fallback_flow_to_ollama(self, mock_glob, mock_ollama, m
         assert result == "Ollama server fallback integration test successful!"
         mock_glob.assert_called_once_with(f'/data1/GGUF/{TEST_OLLAMA_MODEL}/*.gguf')
         mock_ollama.assert_called_once()
+
+    def test_server_mode_passes_json_schema_to_llama_server(self, tmp_path, mock_requests_post, mock_llama_server_url):
+        """
+        chat_with_model (server mode) should forward a json_schema file path
+        and llama-server should receive the parsed schema in its JSON body.
+        """
+        test_schema = {"schema": {"type": "object", "properties": {"answer": {"type": "string"}}}}
+
+        with patch('ai_server.server.is_llamacpp_available', return_value=True):
+            mock_response = MagicMock()
+            mock_response.status_code = 200
+            mock_response.json.return_value = {"choices": [{"message": {"content": "Schema-aware server reply"}}]}
+            mock_requests_post.return_value = mock_response
+
+            result = chat_with_model(
+                TEST_LLAMACPP_MODEL, "Give me an answer", llama_mode="server", json_schema=test_schema
+            )
+
+            assert result == "Schema-aware server reply"
+
+            # Verify POST call
+            mock_requests_post.assert_called_once()
+            args, kwargs = mock_requests_post.call_args
+            assert args[0] == "http://localhost:8080/v1/chat/completions"
+
+            body = kwargs["json"]
+            assert body["model"] == TEST_LLAMACPP_MODEL
+            assert body["messages"][0]["content"] == "Give me an answer"
+            assert body["json_schema"] == {"type": "object", "properties": {"answer": {"type": "string"}}}

test/test_system_prompt.py

Lines changed: 3 additions & 1 deletion
@@ -77,4 +77,6 @@ def test_chat_with_model_routing(self, mock_available, mock_chat):
         mock_chat.return_value = "result"
 
         chat_with_model(TEST_MODEL, TEST_USER_CONTENT, 'cli', TEST_SYSTEM_PROMPT)
-        mock_chat.assert_called_once_with(TEST_MODEL, TEST_USER_CONTENT, system_prompt=TEST_SYSTEM_PROMPT, image_files=None)
+        mock_chat.assert_called_once_with(
+            TEST_MODEL, TEST_USER_CONTENT, system_prompt=TEST_SYSTEM_PROMPT, image_files=None, json_schema=None
+        )

test/test_system_prompt_api.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@ def test_api_with_system_prompt(self, mock_chat, mock_redis, client):
 
         assert response.status_code == 200
 
-        mock_chat.assert_called_once_with(TEST_MODEL, TEST_USER_CONTENT, 'cli', TEST_SYSTEM_PROMPT, [])
+        mock_chat.assert_called_once_with(TEST_MODEL, TEST_USER_CONTENT, 'cli', TEST_SYSTEM_PROMPT, [], None)
 
     @patch('ai_server.server.REDIS_CONNECTION')
     @patch('ai_server.server.chat_with_model')
@@ -54,7 +54,7 @@ def test_api_without_system_prompt(self, mock_chat, mock_redis, client):
 
         assert response.status_code == 200
 
-        mock_chat.assert_called_once_with(TEST_MODEL, TEST_USER_CONTENT, 'cli', None, [])
+        mock_chat.assert_called_once_with(TEST_MODEL, TEST_USER_CONTENT, 'cli', None, [], None)
 
     @patch('ai_server.server.REDIS_CONNECTION')
     def test_api_authentication_still_required(self, mock_redis, client):
