Skip to content

Commit bf94229

Browse files
fixed integration tests
1 parent 836fd54 commit bf94229

File tree

1 file changed

+76
-68
lines changed

1 file changed

+76
-68
lines changed

tests/integration/test_base_flow.py

Lines changed: 76 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
import os
2-
import copy
32
import pytest
43

54
from minds.client import Client
5+
from minds.exceptions import ObjectNotFound, MindNameInvalid, DatasourceNameInvalid
66

77
import logging
88
logging.basicConfig(level=logging.DEBUG)
99

10-
from minds.datasources.examples import example_ds
11-
from minds.datasources import DatabaseConfig, DatabaseTables
1210

13-
from minds.exceptions import ObjectNotFound, MindNameInvalid, DatasourceNameInvalid
11+
# TODO: Validate these tests and ensure coverage
1412

1513

1614
def get_client():
@@ -20,6 +18,23 @@ def get_client():
2018
return Client(api_key, base_url=base_url)
2119

2220

21+
def get_example_datasource_config():
22+
"""Get example datasource configuration parameters"""
23+
return {
24+
'name': 'example_ds',
25+
'engine': 'postgres',
26+
'description': 'Minds example database',
27+
'connection_data': {
28+
"user": "demo_user",
29+
"password": "demo_password",
30+
"host": "samples.mindsdb.com",
31+
"port": "5432",
32+
"database": "demo",
33+
"schema": "demo_data"
34+
}
35+
}
36+
37+
2338
def test_wrong_api_key():
2439
base_url = 'https://dev.mdb.ai'
2540
client = Client('api_key', base_url=base_url)
@@ -29,29 +44,30 @@ def test_wrong_api_key():
2944

3045
def test_datasources():
3146
client = get_client()
47+
example_ds_config = get_example_datasource_config()
3248

3349
# remove previous object
3450
try:
35-
client.datasources.drop(example_ds.name, force=True)
51+
client.datasources.drop(example_ds_config['name'])
3652
except ObjectNotFound:
3753
...
3854

3955
# create
40-
ds = client.datasources.create(example_ds)
41-
assert ds.name == example_ds.name
42-
ds = client.datasources.create(example_ds, update=True)
43-
assert ds.name == example_ds.name
44-
45-
valid_ds_name = example_ds.name
56+
ds = client.datasources.create(**example_ds_config)
57+
assert ds.name == example_ds_config['name']
58+
59+
# create with replace
60+
ds = client.datasources.create(**example_ds_config, replace=True)
61+
assert ds.name == example_ds_config['name']
4662

63+
# test invalid datasource name
4764
with pytest.raises(DatasourceNameInvalid):
48-
example_ds.name = "invalid-ds-name"
49-
client.datasources.create(example_ds)
50-
51-
example_ds.name = valid_ds_name
65+
invalid_config = example_ds_config.copy()
66+
invalid_config['name'] = "invalid-ds-name"
67+
client.datasources.create(**invalid_config)
5268

5369
# get
54-
ds = client.datasources.get(example_ds.name)
70+
ds = client.datasources.get(example_ds_config['name'])
5571

5672
# list
5773
ds_list = client.datasources.list()
@@ -63,14 +79,13 @@ def test_datasources():
6379

6480
def test_minds():
6581
client = get_client()
82+
example_ds_config = get_example_datasource_config()
6683

6784
ds_all_name = 'test_datasource_' # unlimited tables
6885
ds_rentals_name = 'test_datasource2_' # limited to home rentals
6986
mind_name = 'int_test_mind_'
7087
invalid_mind_name = 'mind-123'
7188
mind_name2 = 'int_test_mind2_'
72-
prompt1 = 'answer in spanish'
73-
prompt2 = 'you are data expert'
7489

7590
# remove previous objects
7691
for name in (mind_name, mind_name2):
@@ -80,66 +95,69 @@ def test_minds():
8095
...
8196

8297
# prepare datasources
83-
ds_all_cfg = copy.copy(example_ds)
84-
ds_all_cfg.name = ds_all_name
85-
ds_all = client.datasources.create(ds_all_cfg, update=True)
98+
ds_all_config = example_ds_config.copy()
99+
ds_all_config['name'] = ds_all_name
100+
ds_all = client.datasources.create(**ds_all_config, replace=True)
86101

87-
# second datasource
88-
ds_rentals_cfg = copy.copy(example_ds)
89-
ds_rentals_cfg.name = ds_rentals_name
90-
ds_rentals_cfg.tables = ['home_rentals']
102+
# second datasource
103+
ds_rentals_config = example_ds_config.copy()
104+
ds_rentals_config['name'] = ds_rentals_name
105+
# Note: In the new API, tables are specified when adding datasource to mind, not when creating datasource
91106

92-
# create
107+
# create mind with invalid name should fail
93108
with pytest.raises(MindNameInvalid):
94109
client.minds.create(
95110
invalid_mind_name,
96-
datasources=[ds_all],
111+
datasources=[{'name': ds_all.name}],
97112
provider='openai'
98113
)
99114

115+
# create mind
100116
mind = client.minds.create(
101117
mind_name,
102-
datasources=[ds_all],
118+
datasources=[{'name': ds_all.name}],
103119
provider='openai'
104120
)
121+
122+
# create mind with replace
105123
mind = client.minds.create(
106124
mind_name,
107125
replace=True,
108-
datasources=[ds_all.name, ds_rentals_cfg],
109-
prompt_template=prompt1
110-
)
111-
mind = client.minds.create(
112-
mind_name,
113-
update=True,
114-
datasources=[ds_all.name, ds_rentals_cfg],
115-
prompt_template=prompt1
126+
datasources=[
127+
{'name': ds_all.name},
128+
{'name': ds_rentals_name, 'tables': ['home_rentals']}
129+
]
116130
)
117131

132+
# Create the second datasource that will be used later
133+
ds_rentals = client.datasources.create(**ds_rentals_config, replace=True)
134+
118135
# get
119136
mind = client.minds.get(mind_name)
120137
assert len(mind.datasources) == 2
121-
assert mind.prompt_template == prompt1
122138

123139
# list
124140
mind_list = client.minds.list()
125141
assert len(mind_list) > 0
126142

127-
# completion with prompt 1
143+
# completion test
128144
answer = mind.completion('say hello')
129-
assert 'hola' in answer.lower()
145+
assert len(answer) > 0 # Just check that we get a response
130146

131-
# rename & update
132-
mind.update(
133-
name=mind_name2,
134-
datasources=[ds_all.name],
135-
prompt_template=prompt2
147+
# rename & update using client.minds.update
148+
updated_mind = client.minds.update(
149+
name=mind_name,
150+
new_name=mind_name2,
151+
datasources=[{'name': ds_all.name}]
136152
)
153+
assert updated_mind.name == mind_name2
154+
assert len(updated_mind.datasources) == 1
137155

138156
with pytest.raises(MindNameInvalid):
139-
mind.update(
140-
name=invalid_mind_name,
141-
datasources=[ds_all.name],
142-
prompt_template=prompt2
157+
client.minds.update(
158+
name=mind_name2,
159+
new_name=invalid_mind_name,
160+
datasources=[{'name': ds_all.name}]
143161
)
144162

145163
with pytest.raises(ObjectNotFound):
@@ -148,42 +166,32 @@ def test_minds():
148166

149167
mind = client.minds.get(mind_name2)
150168
assert len(mind.datasources) == 1
151-
assert mind.prompt_template == prompt2
152169

153170
# add datasource
154-
mind.add_datasource(ds_rentals_cfg)
171+
mind.add_datasource(ds_rentals.name)
155172
assert len(mind.datasources) == 2
156173

157-
# del datasource
158-
mind.del_datasource(ds_rentals_cfg.name)
174+
# remove datasource
175+
mind.remove_datasource(ds_rentals.name)
159176
assert len(mind.datasources) == 1
160177

161178
# ask about data
162179
answer = mind.completion('what is max rental price in home rental?')
163180
assert '5602' in answer.replace(' ', '').replace(',', '')
164181

165182
# limit tables
166-
mind.del_datasource(ds_all.name)
167-
mind.add_datasource(ds_rentals_name)
183+
mind.remove_datasource(ds_all.name)
184+
mind.add_datasource(ds_rentals.name, tables=['home_rentals'])
168185
assert len(mind.datasources) == 1
169186

170187
check_mind_can_see_only_rentals(mind)
171188

172-
# test ds with limited tables
173-
ds_all_limited = DatabaseTables(
174-
name=ds_all_name,
175-
tables=['home_rentals']
176-
)
177-
# mind = client.minds.create(
178-
# 'mind_ds_limited_',
179-
# replace=True,
180-
# datasources=[ds_all],
181-
# prompt_template=prompt2
182-
# )
183-
mind.update(
189+
# test ds with limited tables - use client.minds.update instead of DatabaseTables
190+
client.minds.update(
184191
name=mind.name,
185-
datasources=[ds_all_limited],
192+
datasources=[{'name': ds_all.name, 'tables': ['home_rentals']}]
186193
)
194+
mind = client.minds.get(mind.name) # refresh mind object
187195
check_mind_can_see_only_rentals(mind)
188196

189197
# stream completion
@@ -196,7 +204,7 @@ def test_minds():
196204
# drop
197205
client.minds.drop(mind_name2)
198206
client.datasources.drop(ds_all.name)
199-
client.datasources.drop(ds_rentals_cfg.name)
207+
client.datasources.drop(ds_rentals.name)
200208

201209
def check_mind_can_see_only_rentals(mind):
202210
answer = mind.completion('what is max rental price in home rental?')

0 commit comments

Comments
 (0)