Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@ mind3.add_datasource(postgres_config) # Using the config
mind3.add_datasource(datasource) # Using the data source object
```

Create mind with tables restriction for datasource:
```python
from minds.datasources.datasources import DatabaseTables
datasource = DatabaseTables(
name='my_db',
tables=['table1', 'table1'],
)
mind4 = client.minds.create(name='mind_name', datasources=[datasource])
```


### Managing Minds

You can create a mind or replace an existing one with the same name.
Expand Down
5 changes: 5 additions & 0 deletions examples/base_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@
# with prompt template
mind = client.minds.create(name='mind_name', prompt_template='You are codding assistant')

# restrict tables for datasource in context of the mind:
from minds.datasources.datasources import DatabaseTables
datasource = DatabaseTables(name='my_datasource', tables=['table1', 'table1'])
mind = client.minds.create(name='mind_name', datasources=[datasource])

# or add to existed mind
mind = client.minds.create(name='mind_name')
# by config
Expand Down
23 changes: 21 additions & 2 deletions minds/datasources/datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,35 @@
import minds.utils as utils
import minds.exceptions as exc

class DatabaseConfig(BaseModel):

class DatabaseConfigBase(BaseModel):
"""
Base class
"""
name: str
tables: Union[List[str], None] = []


class DatabaseTables(DatabaseConfigBase):
"""
Used when only database and tables are required to be defined. For example in minds.create
"""
...


class DatabaseConfig(DatabaseConfigBase):
"""
Used to define datasource before creating it.
"""
engine: str
description: str
connection_data: Union[dict, None] = {}
tables: Union[List[str], None] = []


class Datasource(DatabaseConfig):
"""
Existed datasource. It is returned by this SDK when datasource is queried from server
"""
...


Expand Down
51 changes: 28 additions & 23 deletions minds/minds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from openai import OpenAI
import minds.utils as utils
import minds.exceptions as exc
from minds.datasources import Datasource, DatabaseConfig
from minds.datasources import Datasource, DatabaseConfig, DatabaseTables, DatabaseConfigBase
from minds.knowledge_bases import KnowledgeBase, KnowledgeBaseConfig

DEFAULT_PROMPT_TEMPLATE = 'Use your database tools to answer the user\'s question: {{question}}'
Expand Down Expand Up @@ -89,11 +89,10 @@ def update(
utils.validate_mind_name(name)

if datasources is not None:
ds_names = []
ds_list = []
for ds in datasources:
ds = self.client.minds._check_datasource(ds)
ds_names.append(ds)
data['datasources'] = ds_names
ds_list.append(self.client.minds._check_datasource(ds))
data['datasources'] = ds_list

if knowledge_bases is not None:
kb_names = []
Expand Down Expand Up @@ -145,7 +144,7 @@ def add_datasource(self, datasource: Datasource):
:param datasource: input datasource
"""

ds_name = self.client.minds._check_datasource(datasource)
ds_name = self.client.minds._check_datasource(datasource)['name']

self.api.post(
f'/projects/{self.project}/minds/{self.name}/datasources',
Expand Down Expand Up @@ -279,20 +278,28 @@ def get(self, name: str) -> Mind:
item = self.api.get(f'/projects/{self.project}/minds/{name}').json()
return Mind(self.client, **item)

def _check_datasource(self, ds) -> str:
if isinstance(ds, Datasource):
ds = ds.name
elif isinstance(ds, DatabaseConfig):
# if not exists - create
try:
self.client.datasources.get(ds.name)
except exc.ObjectNotFound:
self.client.datasources.create(ds)
def _check_datasource(self, ds) -> dict:
if isinstance(ds, DatabaseConfigBase):
res = {'name': ds.name}

if isinstance(ds, DatabaseTables):
if ds.tables:
res['tables'] = ds.tables

if isinstance(ds, DatabaseConfig):
# if not exists - create
try:
self.client.datasources.get(ds.name)
except exc.ObjectNotFound:
self.client.datasources.create(ds)

ds = ds.name
elif not isinstance(ds, str):
elif isinstance(ds, str):
res = {'name': ds}
else:
raise ValueError(f'Unknown type of datasource: {ds}')
return ds

return res


def _check_knowledge_base(self, knowledge_base) -> str:
if isinstance(knowledge_base, KnowledgeBase):
Expand Down Expand Up @@ -355,12 +362,10 @@ def create(
except exc.ObjectNotFound:
...

ds_names = []
ds_list = []
if datasources:
for ds in datasources:
ds = self._check_datasource(ds)

ds_names.append(ds)
ds_list.append(self._check_datasource(ds))

kb_names = []
if knowledge_bases:
Expand Down Expand Up @@ -390,7 +395,7 @@ def create(
'model_name': model_name,
'provider': provider,
'parameters': parameters,
'datasources': ds_names,
'datasources': ds_list,
'knowledge_bases': kb_names
}
)
Expand Down
71 changes: 45 additions & 26 deletions tests/integration/test_base_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
logging.basicConfig(level=logging.DEBUG)

from minds.datasources.examples import example_ds
from minds.datasources import DatabaseConfig
from minds.datasources import DatabaseConfig, DatabaseTables

from minds.exceptions import ObjectNotFound, MindNameInvalid, DatasourceNameInvalid

Expand Down Expand Up @@ -64,8 +64,8 @@ def test_datasources():
def test_minds():
client = get_client()

ds_name = 'test_datasource_'
ds_name2 = 'test_datasource2_'
ds_all_name = 'test_datasource_' # unlimited tables
ds_rentals_name = 'test_datasource2_' # limited to home rentals
mind_name = 'int_test_mind_'
invalid_mind_name = 'mind-123'
mind_name2 = 'int_test_mind2_'
Expand All @@ -80,38 +80,38 @@ def test_minds():
...

# prepare datasources
ds_cfg = copy.copy(example_ds)
ds_cfg.name = ds_name
ds = client.datasources.create(example_ds, update=True)
ds_all_cfg = copy.copy(example_ds)
ds_all_cfg.name = ds_all_name
ds_all = client.datasources.create(ds_all_cfg, update=True)

# second datasource
ds2_cfg = copy.copy(example_ds)
ds2_cfg.name = ds_name2
ds2_cfg.tables = ['home_rentals']
ds_rentals_cfg = copy.copy(example_ds)
ds_rentals_cfg.name = ds_rentals_name
ds_rentals_cfg.tables = ['home_rentals']

# create
with pytest.raises(MindNameInvalid):
client.minds.create(
invalid_mind_name,
datasources=[ds],
datasources=[ds_all],
provider='openai'
)

mind = client.minds.create(
mind_name,
datasources=[ds],
datasources=[ds_all],
provider='openai'
)
mind = client.minds.create(
mind_name,
replace=True,
datasources=[ds.name, ds2_cfg],
datasources=[ds_all.name, ds_rentals_cfg],
prompt_template=prompt1
)
mind = client.minds.create(
mind_name,
update=True,
datasources=[ds.name, ds2_cfg],
datasources=[ds_all.name, ds_rentals_cfg],
prompt_template=prompt1
)

Expand All @@ -131,14 +131,14 @@ def test_minds():
# rename & update
mind.update(
name=mind_name2,
datasources=[ds.name],
datasources=[ds_all.name],
prompt_template=prompt2
)

with pytest.raises(MindNameInvalid):
mind.update(
name=invalid_mind_name,
datasources=[ds.name],
datasources=[ds_all.name],
prompt_template=prompt2
)

Expand All @@ -151,28 +151,40 @@ def test_minds():
assert mind.prompt_template == prompt2

# add datasource
mind.add_datasource(ds2_cfg)
mind.add_datasource(ds_rentals_cfg)
assert len(mind.datasources) == 2

# del datasource
mind.del_datasource(ds2_cfg.name)
mind.del_datasource(ds_rentals_cfg.name)
assert len(mind.datasources) == 1

# ask about data
answer = mind.completion('what is max rental price in home rental?')
assert '5602' in answer.replace(' ', '').replace(',', '')

# limit tables
mind.del_datasource(ds.name)
mind.add_datasource(ds_name2)
mind.del_datasource(ds_all.name)
mind.add_datasource(ds_rentals_name)
assert len(mind.datasources) == 1

answer = mind.completion('what is max rental price in home rental?')
assert '5602' in answer.replace(' ', '').replace(',', '')
check_mind_can_see_only_rentals(mind)

# not accessible table
answer = mind.completion('what is max price in car sales?')
assert '145000' not in answer.replace(' ', '').replace(',', '')
# test ds with limited tables
ds_all_limited = DatabaseTables(
name=ds_all_name,
tables=['home_rentals']
)
# mind = client.minds.create(
# 'mind_ds_limited_',
# replace=True,
# datasources=[ds_all],
# prompt_template=prompt2
# )
mind.update(
name=mind.name,
datasources=[ds_all_limited],
)
check_mind_can_see_only_rentals(mind)

# stream completion
success = False
Expand All @@ -183,6 +195,13 @@ def test_minds():

# drop
client.minds.drop(mind_name2)
client.datasources.drop(ds.name)
client.datasources.drop(ds2_cfg.name)
client.datasources.drop(ds_all.name)
client.datasources.drop(ds_rentals_cfg.name)

def check_mind_can_see_only_rentals(mind):
answer = mind.completion('what is max rental price in home rental?')
assert '5602' in answer.replace(' ', '').replace(',', '')

# not accessible table
answer = mind.completion('what is max price in car sales?')
assert '145000' not in answer.replace(' ', '').replace(',', '')
54 changes: 51 additions & 3 deletions tests/unit/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import patch


from minds.datasources.datasources import DatabaseTables
from minds.datasources.examples import example_ds
from minds.knowledge_bases import EmbeddingConfig, KnowledgeBaseConfig, VectorStoreConfig

Expand Down Expand Up @@ -322,14 +323,59 @@ def check_mind_created(mind, mock_post, create_params, url):
args, kwargs = mock_post.call_args
assert args[0].endswith(url)
request = kwargs['json']
for key in ('name', 'datasources', 'knowledge_bases', 'provider', 'model_name'),:
assert request.get(key) == create_params.get(key)
for key in ('name', 'datasources', 'knowledge_bases', 'provider', 'model_name'):
req, param = request.get(key), create_params.get(key)
if key == 'datasources':
param = [{'name': param[0]}]

assert req == param

assert create_params.get('prompt_template') == request.get('parameters', {}).get('prompt_template')

self.compare_mind(mind, self.mind_json)

check_mind_created(mind, mock_post, create_params, '/api/projects/mindsdb/minds')

# -- datasource with tables --
knowledge_bases = ['example_kb']
provider = 'openai'

response_mock(mock_get, self.mind_json)

ds_conf = DatabaseTables(
name='my_db',
tables=['table1', 'table1'],
)
create_params = {
'name': mind_name,
'prompt_template': prompt_template,
'datasources': [ds_conf],
'knowledge_bases': knowledge_bases
}
mind = client.minds.create(**create_params)

def check_mind_created(mind, mock_post, create_params, url):
args, kwargs = mock_post.call_args
assert args[0].endswith(url)
request = kwargs['json']
for key in ('name', 'datasources', 'knowledge_bases', 'provider', 'model_name'):
req, param = request.get(key), create_params.get(key)
if key == 'datasources':
if param is not None:
ds = param[0]
param = [{'name': ds.name, 'tables': ds.tables}]
else:
param = []
if key == 'knowledge_bases' and param is None:
param = []
assert req == param
assert create_params.get('prompt_template') == request.get('parameters', {}).get('prompt_template')

self.compare_mind(mind, self.mind_json)

check_mind_created(mind, mock_post, create_params, '/api/projects/mindsdb/minds')


# -- with replace --
create_params = {
'name': mind_name,
Expand Down Expand Up @@ -375,7 +421,9 @@ def test_update(self, mock_patch, mock_get):
args, kwargs = mock_patch.call_args
assert args[0].endswith(f'/api/projects/mindsdb/minds/{self.mind_json["name"]}')

assert kwargs['json'] == update_params
params = update_params.copy()
params['datasources'] = [{'name': 'ds_name'}]
assert kwargs['json'] == params

@patch('requests.get')
def test_get(self, mock_get):
Expand Down
Loading