Skip to content

Commit 5577004

Browse files
committed
implemented chat completion and chat workspace methods with docstrings and tests
1 parent c5293d9 commit 5577004

File tree

3 files changed

+363
-3
lines changed

3 files changed

+363
-3
lines changed

meilisearch/_httprequests.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44
from functools import lru_cache
5-
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, Union
5+
from typing import Any, Callable, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
66

77
import requests
88

@@ -146,6 +146,72 @@ def delete(
146146
) -> Any:
147147
return self.send_request(requests.delete, path, body)
148148

149+
def post_stream(
150+
self,
151+
path: str,
152+
body: Optional[
153+
Union[Mapping[str, Any], Sequence[Mapping[str, Any]], List[str], bytes, str]
154+
] = None,
155+
content_type: Optional[str] = "application/json",
156+
*,
157+
serializer: Optional[Type[json.JSONEncoder]] = None,
158+
) -> requests.Response:
159+
"""Send a POST request with streaming enabled.
160+
161+
Returns the raw response object for streaming consumption.
162+
"""
163+
if content_type:
164+
self.headers["Content-Type"] = content_type
165+
try:
166+
request_path = self.config.url + "/" + path
167+
168+
if isinstance(body, bytes):
169+
response = requests.post(
170+
request_path,
171+
timeout=self.config.timeout,
172+
headers=self.headers,
173+
data=body,
174+
stream=True,
175+
)
176+
else:
177+
serialize_body = isinstance(body, dict) or body
178+
data = (
179+
json.dumps(body, cls=serializer)
180+
if isinstance(body, bool) or serialize_body
181+
else "" if body == "" else "null"
182+
)
183+
184+
response = requests.post(
185+
request_path,
186+
timeout=self.config.timeout,
187+
headers=self.headers,
188+
data=data,
189+
stream=True,
190+
)
191+
192+
# For streaming responses, we validate status but don't parse JSON
193+
if not response.ok:
194+
response.raise_for_status()
195+
196+
return response
197+
198+
except requests.exceptions.Timeout as err:
199+
raise MeilisearchTimeoutError(str(err)) from err
200+
except requests.exceptions.ConnectionError as err:
201+
raise MeilisearchCommunicationError(str(err)) from err
202+
except requests.exceptions.HTTPError as err:
203+
raise MeilisearchApiError(str(err), response) from err
204+
except requests.exceptions.InvalidSchema as err:
205+
if "://" not in self.config.url:
206+
raise MeilisearchCommunicationError(
207+
f"""
208+
Invalid URL {self.config.url}, no scheme/protocol supplied.
209+
Did you mean https://{self.config.url}?
210+
"""
211+
) from err
212+
213+
raise MeilisearchCommunicationError(str(err)) from err
214+
149215
@staticmethod
150216
def __to_json(request: requests.Response) -> Any:
151217
if request.content == b"":

meilisearch/client.py

Lines changed: 152 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
import hmac
99
import json
1010
import re
11-
from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
11+
from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
1212
from urllib import parse
1313

1414
from meilisearch._httprequests import HttpRequests
1515
from meilisearch.config import Config
16-
from meilisearch.errors import MeilisearchError
16+
from meilisearch.errors import MeilisearchApiError, MeilisearchCommunicationError, MeilisearchError
1717
from meilisearch.index import Index
1818
from meilisearch.models.key import Key, KeysResults
1919
from meilisearch.models.task import Batch, BatchResults, Task, TaskInfo, TaskResults
@@ -795,6 +795,156 @@ def get_all_networks(self) -> Dict[str, str]:
795795
"""
796796
return self.http.get(path=f"{self.config.paths.network}")
797797

798+
def create_chat_completion(
799+
self,
800+
workspace_uid: str,
801+
messages: List[Dict[str, str]],
802+
model: str = "gpt-3.5-turbo",
803+
stream: bool = True,
804+
) -> Iterator[Dict[str, Any]]:
805+
"""Streams a chat completion from the Meilisearch chat API.
806+
807+
Parameters
808+
----------
809+
workspace_uid:
810+
Unique identifier of the chat workspace to use.
811+
messages:
812+
List of message dicts (e.g. {"role": "user", "content": "..."}) comprising the chat history.
813+
model:
814+
The model name to use for completion (should correspond to the LLM in workspace settings).
815+
stream:
816+
Whether to stream the response. Must be True for now (only streaming is supported).
817+
818+
Returns
819+
-------
820+
chunks:
821+
Parsed chunks of the completion as Python dicts. Each chunk is a partial response (in OpenAI format).
822+
Iteration ends when the completion is done.
823+
824+
Raises
825+
------
826+
MeilisearchApiError
827+
An error containing details about why Meilisearch can't process your request. Meilisearch error codes are described here: https://www.meilisearch.com/docs/reference/errors/error_codes#meilisearch-errors
828+
MeilisearchCommunicationError
829+
If a network error occurs.
830+
ValueError
831+
If stream=False is passed (not currently supported).
832+
"""
833+
if not stream:
834+
# The API currently only supports streaming responses:
835+
raise ValueError("Non-streaming chat completions are not supported. Use stream=True.")
836+
837+
payload = {
838+
"model": model,
839+
"messages": messages,
840+
"stream": True
841+
}
842+
843+
# Construct the URL for the chat completions route.
844+
endpoint = f"chats/{workspace_uid}/chat/completions"
845+
846+
# Initiate the HTTP POST request in streaming mode.
847+
response = self.http.post_stream(endpoint, body=payload)
848+
849+
try:
850+
# Iterate over the streaming response lines
851+
for raw_line in response.iter_lines():
852+
if raw_line is None or raw_line == b'':
853+
continue
854+
855+
line = raw_line.decode('utf-8')
856+
if line.startswith("data: "):
857+
data = line[len("data: "):]
858+
if data.strip() == "[DONE]":
859+
break
860+
861+
try:
862+
chunk = json.loads(data)
863+
yield chunk
864+
except json.JSONDecodeError as e:
865+
866+
raise MeilisearchCommunicationError(f"Failed to parse chat chunk: {e}")
867+
finally:
868+
response.close()
869+
870+
def get_chat_workspaces(
871+
self,
872+
*,
873+
offset: Optional[int] = None,
874+
limit: Optional[int] = None,
875+
) -> Dict[str, Any]:
876+
"""Get all chat workspaces.
877+
878+
Parameters
879+
----------
880+
offset (optional):
881+
Number of workspaces to skip.
882+
limit (optional):
883+
Maximum number of workspaces to return.
884+
885+
Returns
886+
-------
887+
workspaces
888+
Dictionary containing the list of chat workspaces and pagination information.
889+
890+
Raises
891+
------
892+
MeilisearchApiError
893+
An error containing details about why Meilisearch can't process your request. Meilisearch error codes are described here: https://www.meilisearch.com/docs/reference/errors/error_codes#meilisearch-errors
894+
"""
895+
q = []
896+
if offset is not None:
897+
q.append(f"offset={offset}")
898+
if limit is not None:
899+
q.append(f"limit={limit}")
900+
path = "chats" + ("?" + "&".join(q) if q else "")
901+
return self.http.get(path)
902+
903+
def get_chat_workspace_settings(self, workspace_uid: str) -> Dict[str, Any]:
904+
"""Get the settings for a specific chat workspace.
905+
906+
Parameters
907+
----------
908+
workspace_uid:
909+
Unique identifier of the chat workspace.
910+
911+
Returns
912+
-------
913+
settings:
914+
Dictionary containing the workspace settings.
915+
916+
Raises
917+
------
918+
MeilisearchApiError
919+
An error containing details about why Meilisearch can't process your request. Meilisearch error codes are described here: https://www.meilisearch.com/docs/reference/errors/error_codes#meilisearch-errors
920+
"""
921+
return self.http.get(f"chats/{workspace_uid}/settings")
922+
923+
924+
def update_chat_workspace_settings(
925+
self, workspace_uid: str, settings: Mapping[str, Any]
926+
) -> Dict[str, Any]:
927+
"""Update the settings for a specific chat workspace.
928+
929+
Parameters
930+
----------
931+
workspace_uid:
932+
Unique identifier of the chat workspace.
933+
settings:
934+
Dictionary containing the settings to update.
935+
936+
Returns
937+
-------
938+
settings:
939+
Dictionary containing the updated workspace settings.
940+
941+
Raises
942+
------
943+
MeilisearchApiError
944+
An error containing details about why Meilisearch can't process your request. Meilisearch error codes are described here: https://www.meilisearch.com/docs/reference/errors/error_codes#meilisearch-errors
945+
"""
946+
return self.http.patch(f"chats/{workspace_uid}/settings", body=settings)
947+
798948
@staticmethod
799949
def _base64url_encode(data: bytes) -> str:
800950
return base64.urlsafe_b64encode(data).decode("utf-8").replace("=", "")
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# pylint: disable=invalid-name
2+
3+
import json
4+
from unittest.mock import patch
5+
6+
import pytest
7+
import requests
8+
9+
import meilisearch
10+
from meilisearch.errors import MeilisearchApiError, MeilisearchCommunicationError
11+
from tests import BASE_URL, MASTER_KEY
12+
13+
14+
class MockStreamingResponse:
15+
"""Mock response object for testing streaming functionality."""
16+
17+
def __init__(self, lines, ok=True, status_code=200, text=""):
18+
self.lines = lines
19+
self.ok = ok
20+
self.status_code = status_code
21+
self.text = text
22+
self._closed = False
23+
24+
def iter_lines(self):
25+
"""Simulate iter_lines() method of requests.Response."""
26+
for line in self.lines:
27+
yield line
28+
29+
def close(self):
30+
"""Simulate close() method of requests.Response."""
31+
self._closed = True
32+
33+
def raise_for_status(self):
34+
"""Simulate raise_for_status() method of requests.Response."""
35+
if not self.ok:
36+
raise requests.exceptions.HTTPError(f"HTTP {self.status_code}")
37+
38+
39+
def test_create_chat_completion_basic_stream(client):
40+
"""Test basic streaming functionality with successful response."""
41+
dummy_lines = [
42+
b'data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"delta":{"content":"Hello"}}]}',
43+
b'data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"delta":{"content":" world"}}]}',
44+
b'data: [DONE]'
45+
]
46+
mock_resp = MockStreamingResponse(dummy_lines)
47+
48+
with patch.object(client.http, 'post_stream', return_value=mock_resp) as mock_post:
49+
messages = [{"role": "user", "content": "Hi"}]
50+
chunks = list(client.create_chat_completion("my-assistant", messages=messages))
51+
52+
# Verify the HTTP call was made correctly
53+
mock_post.assert_called_once_with(
54+
"chats/my-assistant/chat/completions",
55+
body={
56+
"model": "gpt-3.5-turbo",
57+
"messages": messages,
58+
"stream": True
59+
}
60+
)
61+
62+
# Verify the chunks are parsed correctly
63+
assert len(chunks) == 2
64+
assert chunks[0]["choices"][0]["delta"]["content"] == "Hello"
65+
assert chunks[1]["choices"][0]["delta"]["content"] == " world"
66+
assert mock_resp._closed
67+
68+
69+
def test_create_chat_completion_stream_false_raises_error(client):
70+
"""Test that stream=False raises ValueError."""
71+
messages = [{"role": "user", "content": "Test"}]
72+
73+
with pytest.raises(ValueError, match="Non-streaming chat completions are not supported"):
74+
list(client.create_chat_completion("my-assistant", messages=messages, stream=False))
75+
76+
77+
def test_create_chat_completion_json_decode_error(client):
78+
"""Test that malformed JSON raises MeilisearchCommunicationError."""
79+
dummy_lines = [
80+
b'data: {"invalid": json}', # Malformed JSON
81+
]
82+
mock_resp = MockStreamingResponse(dummy_lines)
83+
84+
with patch.object(client.http, 'post_stream', return_value=mock_resp):
85+
messages = [{"role": "user", "content": "Test"}]
86+
87+
with pytest.raises(MeilisearchCommunicationError, match="Failed to parse chat chunk"):
88+
list(client.create_chat_completion("my-assistant", messages=messages))
89+
90+
91+
def test_create_chat_completion_http_error_propagated(client):
92+
"""Test that HTTP errors from post_stream are properly propagated."""
93+
with patch.object(client.http, 'post_stream') as mock_post:
94+
error_response = MockStreamingResponse([], ok=False, status_code=400, text='{"message": "API Error"}')
95+
mock_post.side_effect = MeilisearchApiError("API Error", error_response)
96+
messages = [{"role": "user", "content": "Test"}]
97+
98+
with pytest.raises(MeilisearchApiError, match="API Error"):
99+
list(client.create_chat_completion("my-assistant", messages=messages))
100+
101+
102+
def test_get_chat_workspaces(client):
103+
"""Test basic get_chat_workspaces functionality."""
104+
mock_response = {
105+
"results": [
106+
{"uid": "workspace1", "name": "My Workspace", "model": "gpt-3.5-turbo"},
107+
{"uid": "workspace2", "name": "Another Workspace", "model": "gpt-4"}
108+
],
109+
"offset": 0,
110+
"limit": 20,
111+
"total": 2
112+
}
113+
114+
with patch.object(client.http, 'get', return_value=mock_response) as mock_get:
115+
result = client.get_chat_workspaces()
116+
117+
# Verify the HTTP call was made correctly
118+
mock_get.assert_called_once_with("chats")
119+
120+
# Verify the response is returned as-is
121+
assert result == mock_response
122+
123+
124+
def test_update_chat_workspace_settings(client):
125+
"""Test basic update_chat_workspace_settings functionality."""
126+
mock_response = {
127+
"model": "gpt-4-turbo",
128+
"temperature": 0.8,
129+
"max_tokens": 1500
130+
}
131+
132+
settings_update = {
133+
"temperature": 0.8,
134+
"max_tokens": 1500
135+
}
136+
137+
with patch.object(client.http, 'patch', return_value=mock_response) as mock_patch:
138+
result = client.update_chat_workspace_settings("my-workspace", settings_update)
139+
140+
# Verify the HTTP call was made correctly
141+
mock_patch.assert_called_once_with("chats/my-workspace/settings", body=settings_update)
142+
143+
# Verify the response is returned as-is
144+
assert result == mock_response

0 commit comments

Comments
 (0)