Skip to content

Commit ff8017c

Browse files
Merge pull request #13 from mindsdb/unit-tests
Unit tests
2 parents 0ab63f8 + e222e5d commit ff8017c

File tree

5 files changed

+309
-6
lines changed

5 files changed

+309
-6
lines changed

.github/workflows/test_on_pr.yml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
name: PR workflow
2+
3+
on:
4+
pull_request:
5+
branches:
6+
- main
7+
8+
jobs:
9+
test:
10+
runs-on: ${{ matrix.os }}
11+
strategy:
12+
matrix:
13+
os: [ubuntu-latest]
14+
python-version: ['3.8', '3.9', '3.10', '3.11']
15+
steps:
16+
- uses: actions/checkout@v2
17+
18+
- name: Set up Python ${{ matrix.python-version }}
19+
uses: actions/setup-python@v4
20+
with:
21+
python-version: ${{ matrix.python-version }}
22+
23+
- name: Install dependencies
24+
run: |
25+
python -m pip install --upgrade pip
26+
pip install -r requirements.txt
27+
pip install -r requirements_test.txt
28+
- name: Run tests
29+
run: |
30+
env PYTHONPATH=./ pytest tests/unit

minds/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@ class ObjectNotSupported(Exception):
77
...
88

99

10+
class Forbidden(Exception):
11+
...
12+
13+
14+
class Unauthorized(Exception):
15+
...
16+
17+
1018
class UnknownError(Exception):
1119
...
1220

minds/rest_api.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ def _raise_for_status(response):
77
if response.status_code == 404:
88
raise exc.ObjectNotFound(response.text)
99

10+
if response.status_code == 403:
11+
raise exc.Forbidden(response.text)
12+
13+
if response.status_code == 401:
14+
raise exc.Unauthorized(response.text)
15+
1016
if 400 <= response.status_code < 600:
1117
raise exc.UnknownError(f'{response.reason}: {response.text}')
1218

tests/integration/test_base_flow.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import copy
3+
import pytest
34

45
from minds.client import Client
56

@@ -18,6 +19,13 @@ def get_client():
1819
return Client(api_key, base_url=base_url)
1920

2021

22+
def test_wrong_api_key():
23+
base_url = 'https://dev.mindsdb.com'
24+
client = Client('api_key', base_url=base_url)
25+
with pytest.raises(Exception):
26+
client.datasources.get('example_db')
27+
28+
2129
def test_datasources():
2230
client = get_client()
2331

@@ -102,12 +110,9 @@ def test_minds():
102110
'prompt_template': prompt2
103111
}
104112
)
105-
try:
106-
mind = client.minds.get(mind_name)
107-
except ObjectNotFound:
108-
...
109-
else:
110-
raise Exception('mind is not renamed')
113+
with pytest.raises(ObjectNotFound):
114+
# this name not exists
115+
client.minds.get(mind_name)
111116

112117
mind = client.minds.get(mind_name2)
113118
assert len(mind.datasources) == 1

tests/unit/test_unit.py

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
2+
from unittest.mock import Mock
3+
from unittest.mock import patch
4+
5+
6+
from minds.datasources.examples import example_ds
7+
8+
def get_client():
9+
from minds.client import Client
10+
return Client(API_KEY)
11+
12+
from minds import rest_api
13+
14+
# patch _raise_for_status
15+
rest_api._raise_for_status = Mock()
16+
#
17+
def response_mock(mock, data):
18+
def side_effect(*args, **kwargs):
19+
r_mock = Mock()
20+
r_mock.status_code = 200
21+
r_mock.json.return_value = data
22+
return r_mock
23+
mock.side_effect = side_effect
24+
25+
26+
API_KEY = '1234567890abc'
27+
28+
29+
class TestDatasources:
30+
31+
def _compare_ds(self, ds1, ds2):
32+
assert ds1.name == ds2.name
33+
assert ds1.engine == ds2.engine
34+
assert ds1.description == ds2.description
35+
assert ds1.connection_data == ds2.connection_data
36+
assert ds1.tables == ds2.tables
37+
38+
@patch('requests.get')
39+
@patch('requests.post')
40+
@patch('requests.delete')
41+
def test_create_datasources(self, mock_del, mock_post, mock_get):
42+
client = get_client()
43+
response_mock(mock_get, example_ds.model_dump())
44+
45+
ds = client.datasources.create(example_ds)
46+
def check_ds_created(ds, mock_post):
47+
self._compare_ds(ds, example_ds)
48+
args, kwargs = mock_post.call_args
49+
50+
assert kwargs['headers'] == {'Authorization': 'Bearer ' + API_KEY}
51+
assert kwargs['json'] == example_ds.model_dump()
52+
assert args[0] == 'https://mdb.ai/api/datasources'
53+
54+
check_ds_created(ds, mock_post)
55+
56+
# with replace
57+
ds = client.datasources.create(example_ds, replace=True)
58+
args, _ = mock_del.call_args
59+
assert args[0].endswith(f'/api/datasources/{example_ds.name}')
60+
61+
check_ds_created(ds, mock_post)
62+
63+
@patch('requests.get')
64+
def test_get_datasource(self, mock_get):
65+
client = get_client()
66+
67+
response_mock(mock_get, example_ds.model_dump())
68+
ds = client.datasources.get(example_ds.name)
69+
self._compare_ds(ds, example_ds)
70+
71+
args, _ = mock_get.call_args
72+
assert args[0].endswith(f'/api/datasources/{example_ds.name}')
73+
74+
@patch('requests.delete')
75+
def test_delete_datasource(self, mock_del):
76+
client = get_client()
77+
78+
client.datasources.drop('ds_name')
79+
80+
args, _ = mock_del.call_args
81+
assert args[0].endswith(f'/api/datasources/ds_name')
82+
83+
@patch('requests.get')
84+
def test_list_datasources(self, mock_get):
85+
client = get_client()
86+
87+
response_mock(mock_get, [example_ds.model_dump()])
88+
ds_list = client.datasources.list()
89+
assert len(ds_list) == 1
90+
ds = ds_list[0]
91+
self._compare_ds(ds, example_ds)
92+
93+
args, _ = mock_get.call_args
94+
assert args[0].endswith(f'/api/datasources')
95+
96+
97+
class TestMinds:
98+
99+
mind_json = {
100+
'model_name': 'gpt-4o',
101+
'name': 'test_mind',
102+
'datasources': ['example_ds'],
103+
'provider': 'openai',
104+
'parameters': {
105+
'prompt_template': "Answer the user's question"
106+
},
107+
'created_at': 'Thu, 26 Sep 2024 13:40:57 GMT',
108+
'updated_at': 'Thu, 26 Sep 2024 13:40:57 GMT',
109+
}
110+
111+
def compare_mind(self, mind, mind_json):
112+
assert mind.name == mind_json['name']
113+
assert mind.model_name == mind_json['model_name']
114+
assert mind.provider == mind_json['provider']
115+
assert mind.parameters == mind_json['parameters']
116+
117+
@patch('requests.get')
118+
@patch('requests.post')
119+
@patch('requests.delete')
120+
def test_create(self, mock_del, mock_post, mock_get):
121+
client = get_client()
122+
123+
mind_name = 'test_mind'
124+
parameters = {'prompt_template': 'always agree'}
125+
datasources = ['my_ds']
126+
provider = 'openai'
127+
128+
response_mock(mock_get, self.mind_json)
129+
create_params = {
130+
'name': mind_name,
131+
'parameters': parameters,
132+
'datasources': datasources
133+
}
134+
mind = client.minds.create(**create_params)
135+
136+
def check_mind_created(mind, mock_post, create_params):
137+
args, kwargs = mock_post.call_args
138+
assert args[0].endswith('/api/projects/mindsdb/minds')
139+
request = kwargs['json']
140+
for k, v in create_params.items():
141+
assert request[k] == v
142+
143+
self.compare_mind(mind, self.mind_json)
144+
145+
check_mind_created(mind, mock_post, create_params)
146+
147+
# with replace
148+
create_params = {
149+
'name': mind_name,
150+
'parameters': parameters,
151+
'provider': provider,
152+
}
153+
mind = client.minds.create(replace=True, **create_params)
154+
155+
# was deleted
156+
args, _ = mock_del.call_args
157+
assert args[0].endswith(f'/api/projects/mindsdb/minds/{mind_name}')
158+
159+
check_mind_created(mind, mock_post, create_params)
160+
161+
@patch('requests.get')
162+
@patch('requests.patch')
163+
def test_update(self, mock_patch, mock_get):
164+
client = get_client()
165+
166+
response_mock(mock_get, self.mind_json)
167+
mind = client.minds.get('mind_name')
168+
169+
update_params = dict(
170+
name='mind_name2',
171+
datasources=['ds_name'],
172+
provider='ollama',
173+
model_name='llama',
174+
parameters={
175+
'prompt_template': 'be polite'
176+
}
177+
)
178+
mind.update(**update_params)
179+
180+
args, kwargs = mock_patch.call_args
181+
assert args[0].endswith(f'/api/projects/mindsdb/minds/{self.mind_json["name"]}')
182+
183+
assert kwargs['json'] == update_params
184+
185+
@patch('requests.get')
186+
def test_get(self, mock_get):
187+
client = get_client()
188+
189+
response_mock(mock_get, self.mind_json)
190+
191+
mind = client.minds.get('my_mind')
192+
self.compare_mind(mind, self.mind_json)
193+
194+
args, _ = mock_get.call_args
195+
assert args[0].endswith(f'/api/projects/mindsdb/minds/my_mind')
196+
197+
@patch('requests.get')
198+
def test_list(self, mock_get):
199+
client = get_client()
200+
201+
response_mock(mock_get, [self.mind_json])
202+
minds_list = client.minds.list()
203+
assert len(minds_list) == 1
204+
self.compare_mind(minds_list[0], self.mind_json)
205+
206+
args, _ = mock_get.call_args
207+
assert args[0].endswith(f'/api/projects/mindsdb/minds')
208+
209+
@patch('requests.delete')
210+
def test_delete(self, mock_del):
211+
client = get_client()
212+
client.minds.drop('my_name')
213+
214+
args, _ = mock_del.call_args
215+
assert args[0].endswith(f'/api/projects/mindsdb/minds/my_name')
216+
217+
@patch('requests.get')
218+
@patch('minds.minds.OpenAI')
219+
def test_completion(self, mock_openai, mock_get):
220+
client = get_client()
221+
222+
response_mock(mock_get, self.mind_json)
223+
mind = client.minds.get('mind_name')
224+
225+
openai_response = 'how can I assist you today?'
226+
227+
def openai_completion_f(messages, *args, **kwargs):
228+
# echo question
229+
answer = messages[0]['content']
230+
231+
response = Mock()
232+
choice = Mock()
233+
choice.message.content = answer
234+
choice.delta.content = answer # for stream
235+
response.choices = [choice]
236+
237+
if kwargs.get('stream'):
238+
return [response]
239+
else:
240+
return response
241+
242+
mock_openai().chat.completions.create.side_effect = openai_completion_f
243+
244+
question = 'the ultimate question'
245+
246+
answer = mind.completion(question)
247+
assert answer == question
248+
249+
success = False
250+
for chunk in mind.completion(question, stream=True):
251+
if question == chunk.content.lower():
252+
success = True
253+
assert success is True
254+

0 commit comments

Comments
 (0)