Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions mindsdb_sdk/databases.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import Dict, List, Union

from mindsdb_sql_parser.ast.mindsdb import CreateDatabase
from mindsdb_sql_parser.ast import DropDatabase, Identifier
Expand All @@ -15,7 +15,7 @@ class Database:
Allows to work with database (datasource): to use tables and make raw queries

To run native query
At this moment query is just saved in Qeury object and not executed
At this moment query is just saved in Query object and not executed

>>> query = database.query('select * from table1') # returns Query

Expand All @@ -27,11 +27,12 @@ class Database:

"""

def __init__(self, server, name, engine=None):
def __init__(self, server, name: str, engine: str = None, params: Dict = None):
self.server = server
self.name = name
self.engine = engine
self.api = server.api
self.params = params

self.tables = Tables(self, self.api)

Expand All @@ -49,6 +50,7 @@ def query(self, sql: str) -> Query:
Make raw query to integration

:param sql: sql of the query
:param database: name of database to query (uses current database by default)
:return: Query object
"""
return Query(self.api, sql, database=self.name)
Expand All @@ -65,7 +67,7 @@ class Databases(CollectionBase):
# create

>>> db = databases.create('example_db',
... type='postgres',
... engine='postgres',
... connection_args={'host': ''})

# drop database
Expand All @@ -81,11 +83,16 @@ class Databases(CollectionBase):
def __init__(self, api):
self.api = api

def _list_databases(self):
def _list_databases(self) -> Dict[str, Database]:
data = self.api.sql_query(
"select NAME, ENGINE from information_schema.databases where TYPE='data'"
"select NAME, ENGINE, CONNECTION_DATA from information_schema.databases where TYPE='data'"
)
return dict(zip(data.NAME, data.ENGINE))
name_to_db = {}
for _, row in data.iterrows():
name_to_db[row["NAME"]] = Database(
self, row["NAME"], engine=row["ENGINE"], params=row["CONNECTION_DATA"]
)
return name_to_db

def list(self) -> List[Database]:
"""
Expand All @@ -94,9 +101,11 @@ def list(self) -> List[Database]:
:return: list of Database objects
"""
databases = self._list_databases()
return [Database(self, name, engine=engine) for name, engine in databases.items()]
return list(databases.values())

def create(self, name: str, engine: Union[str, Handler], connection_args: dict) -> Database:
def create(
self, name: str, engine: Union[str, Handler], connection_args: Dict
) -> Database:
"""
Create new integration and return it

Expand All @@ -114,7 +123,7 @@ def create(self, name: str, engine: Union[str, Handler], connection_args: dict)
parameters=connection_args,
)
self.api.sql_query(ast_query.to_string())
return Database(self, name, engine=engine)
return Database(self, name, engine=engine, params=connection_args)

def drop(self, name: str):
"""
Expand All @@ -135,4 +144,4 @@ def get(self, name: str) -> Database:
databases = self._list_databases()
if name not in databases:
raise AttributeError("Database doesn't exist")
return Database(self, name, engine=databases[name])
return databases[name]
66 changes: 43 additions & 23 deletions tests/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,25 @@ def test_flow(self, mock_post, mock_put, mock_get):
assert call_args[0][0] == 'https://cloud.mindsdb.com/api/status'

# --------- databases -------------
response_mock(mock_post, pd.DataFrame([{'NAME': 'db1','ENGINE': 'postgres'}]))
response_mock(
mock_post,
pd.DataFrame(
[
{
"NAME": "db1",
"ENGINE": "postgres",
"CONNECTION_DATA": {"host": "zoop"},
}
]
),
)

databases = server.list_databases()

check_sql_call(mock_post, "select NAME, ENGINE from information_schema.databases where TYPE='data'")
check_sql_call(
mock_post,
"select NAME, ENGINE, CONNECTION_DATA from information_schema.databases where TYPE='data'",
)

database = databases[0]
str(database)
Expand Down Expand Up @@ -283,7 +297,7 @@ def test_flow(self, mock_post, mock_put, mock_get):
check_sql_call(mock_post, 'DROP DATABASE `proj1-1`')

# test upload file
response_mock(mock_post, pd.DataFrame([{'NAME': 'files', 'ENGINE': 'file'}]))
response_mock(mock_post, pd.DataFrame([{'NAME': 'files', 'ENGINE': 'file', 'CONNECTION_DATA': {'host': 'woop'}}]))
database = server.get_database('files')
# create file
df = pd.DataFrame([{'s': '1'}, {'s': 'a'}])
Expand Down Expand Up @@ -465,7 +479,6 @@ def check_project_models(self, project, database, mock_post):
}
)


@patch('requests.Session.post')
def check_project_models_versions(self, project, database, mock_post):
# ----------- model version --------------
Expand Down Expand Up @@ -495,7 +508,6 @@ def check_project_models_versions(self, project, database, mock_post):
project.drop_model_version('m1', 1)
check_sql_call(mock_post, f"DROP PREDICTOR m1.`1`")


@patch('requests.Session.post')
def check_database(self, database, mock_post):

Expand Down Expand Up @@ -545,7 +557,6 @@ def check_database(self, database, mock_post):
database.drop_table('t3')
check_sql_call(mock_post, f'drop table {database.name}.t3')


@patch('requests.Session.post')
def check_project_jobs(self, project, mock_post):

Expand Down Expand Up @@ -608,11 +619,11 @@ def test_flow(self, mock_post, mock_put):
assert call_args[1]['json']['email'] == '[email protected]'

# --------- databases -------------
response_mock(mock_post, pd.DataFrame([{'NAME': 'db1', 'ENGINE': 'postgres'}]))
response_mock(mock_post, pd.DataFrame([{'NAME': 'db1', 'ENGINE': 'postgres', 'CONNECTION_DATA': {}}]))

databases = con.databases.list()

check_sql_call(mock_post, "select NAME, ENGINE from information_schema.databases where TYPE='data'")
check_sql_call(mock_post, "select NAME, ENGINE, CONNECTION_DATA from information_schema.databases where TYPE='data'")

database = databases[0]
assert database.name == 'db1'
Expand Down Expand Up @@ -659,7 +670,7 @@ def test_flow(self, mock_post, mock_put):
check_sql_call(mock_post, 'DROP DATABASE `proj1-1`')

# test upload file
response_mock(mock_post, pd.DataFrame([{'NAME': 'files', 'ENGINE': 'file'}]))
response_mock(mock_post, pd.DataFrame([{'NAME': 'files', 'ENGINE': 'file', 'CONNECTION_DATA': {}}]))
database = con.databases.files
# create file
df = pd.DataFrame([{'s': '1'}, {'s': 'a'}])
Expand Down Expand Up @@ -1415,7 +1426,6 @@ def test_create(self, mock_get, mock_post):

assert new_agent == expected_agent


@patch('requests.Session.get')
@patch('requests.Session.put')
# Mock creating new skills.
Expand Down Expand Up @@ -1480,7 +1490,6 @@ def test_update(self, mock_get, mock_put, _):

assert updated_agent == expected_agent


@patch('requests.Session.post')
def test_completion(self, mock_post):
response_mock(mock_post, {
Expand Down Expand Up @@ -1693,18 +1702,29 @@ def test_add_database(self, mock_post, mock_put, mock_get):
'provider': 'mindsdb'
},
])
responses_mock(mock_post, [
# DB get (POST /sql).
pd.DataFrame([
{'NAME': 'existing_db', 'ENGINE': 'postgres'}
]),
# DB tables get (POST /sql).
pd.DataFrame([
{'name': 'existing_table'}
]),
# Skill creation.
{'name': 'new_skill', 'type': 'sql', 'params': {'database': 'existing_db', 'tables': ['existing_table']}}
])
responses_mock(
mock_post,
[
# DB get (POST /sql).
pd.DataFrame(
[
{
"NAME": "existing_db",
"ENGINE": "postgres",
"CONNECTION_DATA": {"host": "boop"},
}
]
),
# DB tables get (POST /sql).
pd.DataFrame([{"name": "existing_table"}]),
# Skill creation.
{
"name": "new_skill",
"type": "sql",
"params": {"database": "existing_db", "tables": ["existing_table"]},
},
],
)
responses_mock(mock_put, [
# Agent update with new skill.
{
Expand Down