Skip to content

Commit e2a908b

Browse files
Merge pull request #68 from mindsdb/ds-tables
Limit tables of datasource in context of the mind
2 parents 9af5cb5 + 465a296 commit e2a908b

File tree

6 files changed

+161
-54
lines changed

6 files changed

+161
-54
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,17 @@ mind3.add_datasource(postgres_config) # Using the config
7575
mind3.add_datasource(datasource) # Using the data source object
7676
```
7777

78+
Create a mind with a table restriction for a datasource:
79+
```python
80+
from minds.datasources.datasources import DatabaseTables
81+
datasource = DatabaseTables(
82+
name='my_db',
83+
tables=['table1', 'table2'],
84+
)
85+
mind4 = client.minds.create(name='mind_name', datasources=[datasource])
86+
```
87+
88+
7889
### Managing Minds
7990

8091
You can create a mind or replace an existing one with the same name.

examples/base_usage.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@
4040
# with prompt template
4141
mind = client.minds.create(name='mind_name', prompt_template='You are codding assistant')
4242

43+
# restrict tables for datasource in context of the mind:
44+
from minds.datasources.datasources import DatabaseTables
45+
datasource = DatabaseTables(name='my_datasource', tables=['table1', 'table1'])
46+
mind = client.minds.create(name='mind_name', datasources=[datasource])
47+
4348
# or add to an existing mind
4449
mind = client.minds.create(name='mind_name')
4550
# by config

minds/datasources/datasources.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,35 @@
44
import minds.utils as utils
55
import minds.exceptions as exc
66

7-
class DatabaseConfig(BaseModel):
87

8+
class DatabaseConfigBase(BaseModel):
9+
"""
10+
Base class
11+
"""
912
name: str
13+
tables: Union[List[str], None] = []
14+
15+
16+
class DatabaseTables(DatabaseConfigBase):
17+
"""
18+
Used when only database and tables are required to be defined. For example in minds.create
19+
"""
20+
...
21+
22+
23+
class DatabaseConfig(DatabaseConfigBase):
24+
"""
25+
Used to define datasource before creating it.
26+
"""
1027
engine: str
1128
description: str
1229
connection_data: Union[dict, None] = {}
13-
tables: Union[List[str], None] = []
1430

1531

1632
class Datasource(DatabaseConfig):
33+
"""
34+
Existing datasource. It is returned by this SDK when a datasource is queried from the server.
35+
"""
1736
...
1837

1938

minds/minds.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from openai import OpenAI
33
import minds.utils as utils
44
import minds.exceptions as exc
5-
from minds.datasources import Datasource, DatabaseConfig
5+
from minds.datasources import Datasource, DatabaseConfig, DatabaseTables, DatabaseConfigBase
66
from minds.knowledge_bases import KnowledgeBase, KnowledgeBaseConfig
77

88
DEFAULT_PROMPT_TEMPLATE = 'Use your database tools to answer the user\'s question: {{question}}'
@@ -89,11 +89,10 @@ def update(
8989
utils.validate_mind_name(name)
9090

9191
if datasources is not None:
92-
ds_names = []
92+
ds_list = []
9393
for ds in datasources:
94-
ds = self.client.minds._check_datasource(ds)
95-
ds_names.append(ds)
96-
data['datasources'] = ds_names
94+
ds_list.append(self.client.minds._check_datasource(ds))
95+
data['datasources'] = ds_list
9796

9897
if knowledge_bases is not None:
9998
kb_names = []
@@ -145,7 +144,7 @@ def add_datasource(self, datasource: Datasource):
145144
:param datasource: input datasource
146145
"""
147146

148-
ds_name = self.client.minds._check_datasource(datasource)
147+
ds_name = self.client.minds._check_datasource(datasource)['name']
149148

150149
self.api.post(
151150
f'/projects/{self.project}/minds/{self.name}/datasources',
@@ -279,20 +278,28 @@ def get(self, name: str) -> Mind:
279278
item = self.api.get(f'/projects/{self.project}/minds/{name}').json()
280279
return Mind(self.client, **item)
281280

282-
def _check_datasource(self, ds) -> str:
283-
if isinstance(ds, Datasource):
284-
ds = ds.name
285-
elif isinstance(ds, DatabaseConfig):
286-
# if not exists - create
287-
try:
288-
self.client.datasources.get(ds.name)
289-
except exc.ObjectNotFound:
290-
self.client.datasources.create(ds)
281+
def _check_datasource(self, ds) -> dict:
282+
if isinstance(ds, DatabaseConfigBase):
283+
res = {'name': ds.name}
284+
285+
if isinstance(ds, DatabaseTables):
286+
if ds.tables:
287+
res['tables'] = ds.tables
288+
289+
if isinstance(ds, DatabaseConfig):
290+
# if not exists - create
291+
try:
292+
self.client.datasources.get(ds.name)
293+
except exc.ObjectNotFound:
294+
self.client.datasources.create(ds)
291295

292-
ds = ds.name
293-
elif not isinstance(ds, str):
296+
elif isinstance(ds, str):
297+
res = {'name': ds}
298+
else:
294299
raise ValueError(f'Unknown type of datasource: {ds}')
295-
return ds
300+
301+
return res
302+
296303

297304
def _check_knowledge_base(self, knowledge_base) -> str:
298305
if isinstance(knowledge_base, KnowledgeBase):
@@ -355,12 +362,10 @@ def create(
355362
except exc.ObjectNotFound:
356363
...
357364

358-
ds_names = []
365+
ds_list = []
359366
if datasources:
360367
for ds in datasources:
361-
ds = self._check_datasource(ds)
362-
363-
ds_names.append(ds)
368+
ds_list.append(self._check_datasource(ds))
364369

365370
kb_names = []
366371
if knowledge_bases:
@@ -390,7 +395,7 @@ def create(
390395
'model_name': model_name,
391396
'provider': provider,
392397
'parameters': parameters,
393-
'datasources': ds_names,
398+
'datasources': ds_list,
394399
'knowledge_bases': kb_names
395400
}
396401
)

tests/integration/test_base_flow.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
logging.basicConfig(level=logging.DEBUG)
99

1010
from minds.datasources.examples import example_ds
11-
from minds.datasources import DatabaseConfig
11+
from minds.datasources import DatabaseConfig, DatabaseTables
1212

1313
from minds.exceptions import ObjectNotFound, MindNameInvalid, DatasourceNameInvalid
1414

@@ -64,8 +64,8 @@ def test_datasources():
6464
def test_minds():
6565
client = get_client()
6666

67-
ds_name = 'test_datasource_'
68-
ds_name2 = 'test_datasource2_'
67+
ds_all_name = 'test_datasource_' # unlimited tables
68+
ds_rentals_name = 'test_datasource2_' # limited to home rentals
6969
mind_name = 'int_test_mind_'
7070
invalid_mind_name = 'mind-123'
7171
mind_name2 = 'int_test_mind2_'
@@ -80,38 +80,38 @@ def test_minds():
8080
...
8181

8282
# prepare datasources
83-
ds_cfg = copy.copy(example_ds)
84-
ds_cfg.name = ds_name
85-
ds = client.datasources.create(example_ds, update=True)
83+
ds_all_cfg = copy.copy(example_ds)
84+
ds_all_cfg.name = ds_all_name
85+
ds_all = client.datasources.create(ds_all_cfg, update=True)
8686

8787
# second datasource
88-
ds2_cfg = copy.copy(example_ds)
89-
ds2_cfg.name = ds_name2
90-
ds2_cfg.tables = ['home_rentals']
88+
ds_rentals_cfg = copy.copy(example_ds)
89+
ds_rentals_cfg.name = ds_rentals_name
90+
ds_rentals_cfg.tables = ['home_rentals']
9191

9292
# create
9393
with pytest.raises(MindNameInvalid):
9494
client.minds.create(
9595
invalid_mind_name,
96-
datasources=[ds],
96+
datasources=[ds_all],
9797
provider='openai'
9898
)
9999

100100
mind = client.minds.create(
101101
mind_name,
102-
datasources=[ds],
102+
datasources=[ds_all],
103103
provider='openai'
104104
)
105105
mind = client.minds.create(
106106
mind_name,
107107
replace=True,
108-
datasources=[ds.name, ds2_cfg],
108+
datasources=[ds_all.name, ds_rentals_cfg],
109109
prompt_template=prompt1
110110
)
111111
mind = client.minds.create(
112112
mind_name,
113113
update=True,
114-
datasources=[ds.name, ds2_cfg],
114+
datasources=[ds_all.name, ds_rentals_cfg],
115115
prompt_template=prompt1
116116
)
117117

@@ -131,14 +131,14 @@ def test_minds():
131131
# rename & update
132132
mind.update(
133133
name=mind_name2,
134-
datasources=[ds.name],
134+
datasources=[ds_all.name],
135135
prompt_template=prompt2
136136
)
137137

138138
with pytest.raises(MindNameInvalid):
139139
mind.update(
140140
name=invalid_mind_name,
141-
datasources=[ds.name],
141+
datasources=[ds_all.name],
142142
prompt_template=prompt2
143143
)
144144

@@ -151,28 +151,40 @@ def test_minds():
151151
assert mind.prompt_template == prompt2
152152

153153
# add datasource
154-
mind.add_datasource(ds2_cfg)
154+
mind.add_datasource(ds_rentals_cfg)
155155
assert len(mind.datasources) == 2
156156

157157
# del datasource
158-
mind.del_datasource(ds2_cfg.name)
158+
mind.del_datasource(ds_rentals_cfg.name)
159159
assert len(mind.datasources) == 1
160160

161161
# ask about data
162162
answer = mind.completion('what is max rental price in home rental?')
163163
assert '5602' in answer.replace(' ', '').replace(',', '')
164164

165165
# limit tables
166-
mind.del_datasource(ds.name)
167-
mind.add_datasource(ds_name2)
166+
mind.del_datasource(ds_all.name)
167+
mind.add_datasource(ds_rentals_name)
168168
assert len(mind.datasources) == 1
169169

170-
answer = mind.completion('what is max rental price in home rental?')
171-
assert '5602' in answer.replace(' ', '').replace(',', '')
170+
check_mind_can_see_only_rentals(mind)
172171

173-
# not accessible table
174-
answer = mind.completion('what is max price in car sales?')
175-
assert '145000' not in answer.replace(' ', '').replace(',', '')
172+
# test ds with limited tables
173+
ds_all_limited = DatabaseTables(
174+
name=ds_all_name,
175+
tables=['home_rentals']
176+
)
177+
# mind = client.minds.create(
178+
# 'mind_ds_limited_',
179+
# replace=True,
180+
# datasources=[ds_all],
181+
# prompt_template=prompt2
182+
# )
183+
mind.update(
184+
name=mind.name,
185+
datasources=[ds_all_limited],
186+
)
187+
check_mind_can_see_only_rentals(mind)
176188

177189
# stream completion
178190
success = False
@@ -183,6 +195,13 @@ def test_minds():
183195

184196
# drop
185197
client.minds.drop(mind_name2)
186-
client.datasources.drop(ds.name)
187-
client.datasources.drop(ds2_cfg.name)
198+
client.datasources.drop(ds_all.name)
199+
client.datasources.drop(ds_rentals_cfg.name)
200+
201+
def check_mind_can_see_only_rentals(mind):
202+
answer = mind.completion('what is max rental price in home rental?')
203+
assert '5602' in answer.replace(' ', '').replace(',', '')
188204

205+
# not accessible table
206+
answer = mind.completion('what is max price in car sales?')
207+
assert '145000' not in answer.replace(' ', '').replace(',', '')

tests/unit/test_unit.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from unittest.mock import patch
44

55

6+
from minds.datasources.datasources import DatabaseTables
67
from minds.datasources.examples import example_ds
78
from minds.knowledge_bases import EmbeddingConfig, KnowledgeBaseConfig, VectorStoreConfig
89

@@ -322,14 +323,59 @@ def check_mind_created(mind, mock_post, create_params, url):
322323
args, kwargs = mock_post.call_args
323324
assert args[0].endswith(url)
324325
request = kwargs['json']
325-
for key in ('name', 'datasources', 'knowledge_bases', 'provider', 'model_name'),:
326-
assert request.get(key) == create_params.get(key)
326+
for key in ('name', 'datasources', 'knowledge_bases', 'provider', 'model_name'):
327+
req, param = request.get(key), create_params.get(key)
328+
if key == 'datasources':
329+
param = [{'name': param[0]}]
330+
331+
assert req == param
332+
333+
assert create_params.get('prompt_template') == request.get('parameters', {}).get('prompt_template')
334+
335+
self.compare_mind(mind, self.mind_json)
336+
337+
check_mind_created(mind, mock_post, create_params, '/api/projects/mindsdb/minds')
338+
339+
# -- datasource with tables --
340+
knowledge_bases = ['example_kb']
341+
provider = 'openai'
342+
343+
response_mock(mock_get, self.mind_json)
344+
345+
ds_conf = DatabaseTables(
346+
name='my_db',
347+
tables=['table1', 'table1'],
348+
)
349+
create_params = {
350+
'name': mind_name,
351+
'prompt_template': prompt_template,
352+
'datasources': [ds_conf],
353+
'knowledge_bases': knowledge_bases
354+
}
355+
mind = client.minds.create(**create_params)
356+
357+
def check_mind_created(mind, mock_post, create_params, url):
358+
args, kwargs = mock_post.call_args
359+
assert args[0].endswith(url)
360+
request = kwargs['json']
361+
for key in ('name', 'datasources', 'knowledge_bases', 'provider', 'model_name'):
362+
req, param = request.get(key), create_params.get(key)
363+
if key == 'datasources':
364+
if param is not None:
365+
ds = param[0]
366+
param = [{'name': ds.name, 'tables': ds.tables}]
367+
else:
368+
param = []
369+
if key == 'knowledge_bases' and param is None:
370+
param = []
371+
assert req == param
327372
assert create_params.get('prompt_template') == request.get('parameters', {}).get('prompt_template')
328373

329374
self.compare_mind(mind, self.mind_json)
330375

331376
check_mind_created(mind, mock_post, create_params, '/api/projects/mindsdb/minds')
332377

378+
333379
# -- with replace --
334380
create_params = {
335381
'name': mind_name,
@@ -375,7 +421,9 @@ def test_update(self, mock_patch, mock_get):
375421
args, kwargs = mock_patch.call_args
376422
assert args[0].endswith(f'/api/projects/mindsdb/minds/{self.mind_json["name"]}')
377423

378-
assert kwargs['json'] == update_params
424+
params = update_params.copy()
425+
params['datasources'] = [{'name': 'ds_name'}]
426+
assert kwargs['json'] == params
379427

380428
@patch('requests.get')
381429
def test_get(self, mock_get):

0 commit comments

Comments
 (0)