diff --git a/README.md b/README.md index 5780968..bb6030a 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,17 @@ mind3.add_datasource(postgres_config) # Using the config mind3.add_datasource(datasource) # Using the data source object ``` +Create a mind with a table restriction for a datasource: +```python +from minds.datasources.datasources import DatabaseTables +datasource = DatabaseTables( + name='my_db', + tables=['table1', 'table2'], +) +mind4 = client.minds.create(name='mind_name', datasources=[datasource]) +``` + + ### Managing Minds You can create a mind or replace an existing one with the same name. diff --git a/examples/base_usage.py b/examples/base_usage.py index 50df904..3296801 100644 --- a/examples/base_usage.py +++ b/examples/base_usage.py @@ -40,6 +40,11 @@ # with prompt template mind = client.minds.create(name='mind_name', prompt_template='You are codding assistant') +# restrict the datasource's tables in the context of the mind: +from minds.datasources.datasources import DatabaseTables +datasource = DatabaseTables(name='my_datasource', tables=['table1', 'table2']) +mind = client.minds.create(name='mind_name', datasources=[datasource]) + # or add to existed mind mind = client.minds.create(name='mind_name') # by config diff --git a/minds/datasources/datasources.py b/minds/datasources/datasources.py index a973dd4..e26a680 100644 --- a/minds/datasources/datasources.py +++ b/minds/datasources/datasources.py @@ -4,16 +4,35 @@ import minds.utils as utils import minds.exceptions as exc -class DatabaseConfig(BaseModel): +class DatabaseConfigBase(BaseModel): + """ + Base class holding the fields shared by all datasource definitions (name and tables) + """ name: str + tables: Union[List[str], None] = [] + + +class DatabaseTables(DatabaseConfigBase): + """ + Used when only the database name and its tables need to be defined, for example in minds.create + """ + ... + + +class DatabaseConfig(DatabaseConfigBase): + """ + Used to define a datasource before creating it. 
+ """ engine: str description: str connection_data: Union[dict, None] = {} - tables: Union[List[str], None] = [] class Datasource(DatabaseConfig): + """ + Existing datasource. It is returned by this SDK when a datasource is queried from the server + """ ... diff --git a/minds/minds.py b/minds/minds.py index 99d151f..75679d2 100644 --- a/minds/minds.py +++ b/minds/minds.py @@ -2,7 +2,7 @@ from openai import OpenAI import minds.utils as utils import minds.exceptions as exc -from minds.datasources import Datasource, DatabaseConfig +from minds.datasources import Datasource, DatabaseConfig, DatabaseTables, DatabaseConfigBase from minds.knowledge_bases import KnowledgeBase, KnowledgeBaseConfig DEFAULT_PROMPT_TEMPLATE = 'Use your database tools to answer the user\'s question: {{question}}' @@ -89,11 +89,10 @@ def update( utils.validate_mind_name(name) if datasources is not None: - ds_names = [] + ds_list = [] for ds in datasources: - ds = self.client.minds._check_datasource(ds) - ds_names.append(ds) - data['datasources'] = ds_names + ds_list.append(self.client.minds._check_datasource(ds)) + data['datasources'] = ds_list if knowledge_bases is not None: kb_names = [] @@ -145,7 +144,7 @@ def add_datasource(self, datasource: Datasource): :param datasource: input datasource """ - ds_name = self.client.minds._check_datasource(datasource) + ds_name = self.client.minds._check_datasource(datasource)['name'] self.api.post( f'/projects/{self.project}/minds/{self.name}/datasources', @@ -279,20 +278,28 @@ def get(self, name: str) -> Mind: item = self.api.get(f'/projects/{self.project}/minds/{name}').json() return Mind(self.client, **item) - def _check_datasource(self, ds) -> str: - if isinstance(ds, Datasource): - ds = ds.name - elif isinstance(ds, DatabaseConfig): - # if not exists - create - try: - self.client.datasources.get(ds.name) - except exc.ObjectNotFound: - self.client.datasources.create(ds) + def _check_datasource(self, ds) -> dict: + if isinstance(ds, DatabaseConfigBase): + res = 
{'name': ds.name} + + if isinstance(ds, DatabaseTables): + if ds.tables: + res['tables'] = ds.tables + + if isinstance(ds, DatabaseConfig): + # if not exists - create + try: + self.client.datasources.get(ds.name) + except exc.ObjectNotFound: + self.client.datasources.create(ds) - ds = ds.name - elif not isinstance(ds, str): + elif isinstance(ds, str): + res = {'name': ds} + else: raise ValueError(f'Unknown type of datasource: {ds}') - return ds + + return res + def _check_knowledge_base(self, knowledge_base) -> str: if isinstance(knowledge_base, KnowledgeBase): @@ -355,12 +362,10 @@ def create( except exc.ObjectNotFound: ... - ds_names = [] + ds_list = [] if datasources: for ds in datasources: - ds = self._check_datasource(ds) - - ds_names.append(ds) + ds_list.append(self._check_datasource(ds)) kb_names = [] if knowledge_bases: @@ -390,7 +395,7 @@ def create( 'model_name': model_name, 'provider': provider, 'parameters': parameters, - 'datasources': ds_names, + 'datasources': ds_list, 'knowledge_bases': kb_names } ) diff --git a/tests/integration/test_base_flow.py b/tests/integration/test_base_flow.py index 66b659b..66c749e 100644 --- a/tests/integration/test_base_flow.py +++ b/tests/integration/test_base_flow.py @@ -8,7 +8,7 @@ logging.basicConfig(level=logging.DEBUG) from minds.datasources.examples import example_ds -from minds.datasources import DatabaseConfig +from minds.datasources import DatabaseConfig, DatabaseTables from minds.exceptions import ObjectNotFound, MindNameInvalid, DatasourceNameInvalid @@ -64,8 +64,8 @@ def test_datasources(): def test_minds(): client = get_client() - ds_name = 'test_datasource_' - ds_name2 = 'test_datasource2_' + ds_all_name = 'test_datasource_' # unlimited tables + ds_rentals_name = 'test_datasource2_' # limited to home rentals mind_name = 'int_test_mind_' invalid_mind_name = 'mind-123' mind_name2 = 'int_test_mind2_' @@ -80,38 +80,38 @@ def test_minds(): ... 
# prepare datasources - ds_cfg = copy.copy(example_ds) - ds_cfg.name = ds_name - ds = client.datasources.create(example_ds, update=True) + ds_all_cfg = copy.copy(example_ds) + ds_all_cfg.name = ds_all_name + ds_all = client.datasources.create(ds_all_cfg, update=True) # second datasource - ds2_cfg = copy.copy(example_ds) - ds2_cfg.name = ds_name2 - ds2_cfg.tables = ['home_rentals'] + ds_rentals_cfg = copy.copy(example_ds) + ds_rentals_cfg.name = ds_rentals_name + ds_rentals_cfg.tables = ['home_rentals'] # create with pytest.raises(MindNameInvalid): client.minds.create( invalid_mind_name, - datasources=[ds], + datasources=[ds_all], provider='openai' ) mind = client.minds.create( mind_name, - datasources=[ds], + datasources=[ds_all], provider='openai' ) mind = client.minds.create( mind_name, replace=True, - datasources=[ds.name, ds2_cfg], + datasources=[ds_all.name, ds_rentals_cfg], prompt_template=prompt1 ) mind = client.minds.create( mind_name, update=True, - datasources=[ds.name, ds2_cfg], + datasources=[ds_all.name, ds_rentals_cfg], prompt_template=prompt1 ) @@ -131,14 +131,14 @@ def test_minds(): # rename & update mind.update( name=mind_name2, - datasources=[ds.name], + datasources=[ds_all.name], prompt_template=prompt2 ) with pytest.raises(MindNameInvalid): mind.update( name=invalid_mind_name, - datasources=[ds.name], + datasources=[ds_all.name], prompt_template=prompt2 ) @@ -151,11 +151,11 @@ def test_minds(): assert mind.prompt_template == prompt2 # add datasource - mind.add_datasource(ds2_cfg) + mind.add_datasource(ds_rentals_cfg) assert len(mind.datasources) == 2 # del datasource - mind.del_datasource(ds2_cfg.name) + mind.del_datasource(ds_rentals_cfg.name) assert len(mind.datasources) == 1 # ask about data @@ -163,16 +163,28 @@ def test_minds(): assert '5602' in answer.replace(' ', '').replace(',', '') # limit tables - mind.del_datasource(ds.name) - mind.add_datasource(ds_name2) + mind.del_datasource(ds_all.name) + mind.add_datasource(ds_rentals_name) 
assert len(mind.datasources) == 1 - answer = mind.completion('what is max rental price in home rental?') - assert '5602' in answer.replace(' ', '').replace(',', '') + check_mind_can_see_only_rentals(mind) - # not accessible table - answer = mind.completion('what is max price in car sales?') - assert '145000' not in answer.replace(' ', '').replace(',', '') + # test ds with limited tables + ds_all_limited = DatabaseTables( + name=ds_all_name, + tables=['home_rentals'] + ) + # mind = client.minds.create( + # 'mind_ds_limited_', + # replace=True, + # datasources=[ds_all], + # prompt_template=prompt2 + # ) + mind.update( + name=mind.name, + datasources=[ds_all_limited], + ) + check_mind_can_see_only_rentals(mind) # stream completion success = False @@ -183,6 +195,13 @@ def test_minds(): # drop client.minds.drop(mind_name2) - client.datasources.drop(ds.name) - client.datasources.drop(ds2_cfg.name) + client.datasources.drop(ds_all.name) + client.datasources.drop(ds_rentals_cfg.name) + +def check_mind_can_see_only_rentals(mind): + answer = mind.completion('what is max rental price in home rental?') + assert '5602' in answer.replace(' ', '').replace(',', '') + # not accessible table + answer = mind.completion('what is max price in car sales?') + assert '145000' not in answer.replace(' ', '').replace(',', '') diff --git a/tests/unit/test_unit.py b/tests/unit/test_unit.py index b7b2c87..3f33fa1 100644 --- a/tests/unit/test_unit.py +++ b/tests/unit/test_unit.py @@ -3,6 +3,7 @@ from unittest.mock import patch +from minds.datasources.datasources import DatabaseTables from minds.datasources.examples import example_ds from minds.knowledge_bases import EmbeddingConfig, KnowledgeBaseConfig, VectorStoreConfig @@ -322,14 +323,59 @@ def check_mind_created(mind, mock_post, create_params, url): args, kwargs = mock_post.call_args assert args[0].endswith(url) request = kwargs['json'] - for key in ('name', 'datasources', 'knowledge_bases', 'provider', 'model_name'),: - assert 
request.get(key) == create_params.get(key) + for key in ('name', 'datasources', 'knowledge_bases', 'provider', 'model_name'): + req, param = request.get(key), create_params.get(key) + if key == 'datasources': + param = [{'name': param[0]}] + + assert req == param + + assert create_params.get('prompt_template') == request.get('parameters', {}).get('prompt_template') + + self.compare_mind(mind, self.mind_json) + + check_mind_created(mind, mock_post, create_params, '/api/projects/mindsdb/minds') + + # -- datasource with tables -- + knowledge_bases = ['example_kb'] + provider = 'openai' + + response_mock(mock_get, self.mind_json) + + ds_conf = DatabaseTables( + name='my_db', + tables=['table1', 'table1'], + ) + create_params = { + 'name': mind_name, + 'prompt_template': prompt_template, + 'datasources': [ds_conf], + 'knowledge_bases': knowledge_bases + } + mind = client.minds.create(**create_params) + + def check_mind_created(mind, mock_post, create_params, url): + args, kwargs = mock_post.call_args + assert args[0].endswith(url) + request = kwargs['json'] + for key in ('name', 'datasources', 'knowledge_bases', 'provider', 'model_name'): + req, param = request.get(key), create_params.get(key) + if key == 'datasources': + if param is not None: + ds = param[0] + param = [{'name': ds.name, 'tables': ds.tables}] + else: + param = [] + if key == 'knowledge_bases' and param is None: + param = [] + assert req == param assert create_params.get('prompt_template') == request.get('parameters', {}).get('prompt_template') self.compare_mind(mind, self.mind_json) check_mind_created(mind, mock_post, create_params, '/api/projects/mindsdb/minds') + # -- with replace -- create_params = { 'name': mind_name, @@ -375,7 +421,9 @@ def test_update(self, mock_patch, mock_get): args, kwargs = mock_patch.call_args assert args[0].endswith(f'/api/projects/mindsdb/minds/{self.mind_json["name"]}') - assert kwargs['json'] == update_params + params = update_params.copy() + params['datasources'] = 
[{'name': 'ds_name'}] + assert kwargs['json'] == params @patch('requests.get') def test_get(self, mock_get):