Skip to content

Commit 68b979a

Browse files
authored
Merge pull request #44 from mindsdb/upsert
Upsert minds/datasources
2 parents b643430 + 5d80f37 commit 68b979a

File tree

5 files changed

+49
-20
lines changed

5 files changed

+49
-20
lines changed

minds/datasources/datasources.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ class DatabaseConfig(BaseModel):
1616
class Datasource(DatabaseConfig):
1717
...
1818

19+
1920
class Datasources:
2021
def __init__(self, client):
2122
self.api = client.api
2223

23-
def create(self, ds_config: DatabaseConfig, replace=False):
24+
def create(self, ds_config: DatabaseConfig, update=False):
2425
"""
2526
Create new datasource and return it
2627
@@ -30,19 +31,16 @@ def create(self, ds_config: DatabaseConfig, replace=False):
3031
- description: str, description of the database. Used by mind to know what data can be got from it.
3132
- connection_data: dict, optional, credentials to connect to database
3233
- tables: list of str, optional, list of allowed tables
34+
:param update: if true - to update datasourse if exists, default is false
3335
:return: datasource object
3436
"""
3537

3638
name = ds_config.name
3739

38-
if replace:
39-
try:
40-
self.get(name)
41-
self.drop(name, force=True)
42-
except exc.ObjectNotFound:
43-
...
44-
45-
self.api.post('/datasources', data=ds_config.model_dump())
40+
if update:
41+
self.api.put('/datasources', data=ds_config.model_dump())
42+
else:
43+
self.api.post('/datasources', data=ds_config.model_dump())
4644
return self.get(name)
4745

4846
def list(self) -> List[Datasource]:

minds/minds.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def create(
243243
datasources=None,
244244
parameters=None,
245245
replace=False,
246+
update=False,
246247
) -> Mind:
247248
"""
248249
Create a new mind and return it
@@ -259,6 +260,7 @@ def create(
259260
:param datasources: list of datasources used by mind, optional
260261
:param parameters, dict: other parameters of the mind, optional
261262
:param replace: if true - to remove existing mind, default is false
263+
:param update: if true - to update mind if exists, default is false
262264
:return: created mind
263265
"""
264266

@@ -284,7 +286,12 @@ def create(
284286
if 'prompt_template' not in parameters:
285287
parameters['prompt_template'] = DEFAULT_PROMPT_TEMPLATE
286288

287-
self.api.post(
289+
if update:
290+
method = self.api.put
291+
else:
292+
method = self.api.post
293+
294+
method(
288295
f'/projects/{self.project}/minds',
289296
data={
290297
'name': name,

minds/rest_api.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ def post(self, url, data):
5757
_raise_for_status(resp)
5858
return resp
5959

60+
def put(self, url, data):
61+
resp = requests.put(
62+
self.base_url + url,
63+
headers=self._headers(),
64+
json=data,
65+
)
66+
67+
_raise_for_status(resp)
68+
return resp
69+
6070
def patch(self, url, data):
6171
resp = requests.patch(
6272
self.base_url + url,

tests/integration/test_base_flow.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def test_datasources():
3737

3838
# create
3939
ds = client.datasources.create(example_ds)
40-
ds = client.datasources.create(example_ds, replace=True)
40+
assert ds.name == example_ds.name
41+
ds = client.datasources.create(example_ds, update=True)
4142
assert ds.name == example_ds.name
4243

4344
# get
@@ -90,6 +91,12 @@ def test_minds():
9091
datasources=[ds.name, ds2_cfg],
9192
prompt_template=prompt1
9293
)
94+
mind = client.minds.create(
95+
mind_name,
96+
update=True,
97+
datasources=[ds.name, ds2_cfg],
98+
prompt_template=prompt1
99+
)
93100

94101
# get
95102
mind = client.minds.get(mind_name)

tests/unit/test_unit.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ def _compare_ds(self, ds1, ds2):
3636
assert ds1.tables == ds2.tables
3737

3838
@patch('requests.get')
39+
@patch('requests.put')
3940
@patch('requests.post')
4041
@patch('requests.delete')
41-
def test_create_datasources(self, mock_del, mock_post, mock_get):
42+
def test_create_datasources(self, mock_del, mock_post, mock_put, mock_get):
4243
client = get_client()
4344
response_mock(mock_get, example_ds.model_dump())
4445

@@ -53,12 +54,9 @@ def check_ds_created(ds, mock_post):
5354

5455
check_ds_created(ds, mock_post)
5556

56-
# with replace
57-
ds = client.datasources.create(example_ds, replace=True)
58-
args, _ = mock_del.call_args
59-
assert args[0].endswith(f'/api/datasources/{example_ds.name}')
60-
61-
check_ds_created(ds, mock_post)
57+
# with update
58+
ds = client.datasources.create(example_ds, update=True)
59+
check_ds_created(ds, mock_put)
6260

6361
@patch('requests.get')
6462
def test_get_datasource(self, mock_get):
@@ -115,9 +113,10 @@ def compare_mind(self, mind, mind_json):
115113
assert mind.parameters == mind_json['parameters']
116114

117115
@patch('requests.get')
116+
@patch('requests.put')
118117
@patch('requests.post')
119118
@patch('requests.delete')
120-
def test_create(self, mock_del, mock_post, mock_get):
119+
def test_create(self, mock_del, mock_post, mock_put, mock_get):
121120
client = get_client()
122121

123122
mind_name = 'test_mind'
@@ -145,7 +144,7 @@ def check_mind_created(mind, mock_post, create_params):
145144

146145
check_mind_created(mind, mock_post, create_params)
147146

148-
# with replace
147+
# -- with replace --
149148
create_params = {
150149
'name': mind_name,
151150
'prompt_template': prompt_template,
@@ -159,6 +158,14 @@ def check_mind_created(mind, mock_post, create_params):
159158

160159
check_mind_created(mind, mock_post, create_params)
161160

161+
# -- with update --
162+
mock_del.reset_mock()
163+
mind = client.minds.create(update=True, **create_params)
164+
# is not deleted
165+
assert not mock_del.called
166+
167+
check_mind_created(mind, mock_put, create_params)
168+
162169
@patch('requests.get')
163170
@patch('requests.patch')
164171
def test_update(self, mock_patch, mock_get):

0 commit comments

Comments
 (0)