Skip to content

Commit 430ec4d

Browse files
fixed unit tests
1 parent 12d023e commit 430ec4d

File tree

2 files changed

+126
-73
lines changed

2 files changed

+126
-73
lines changed

mindsdb_sdk/agents.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,9 +565,11 @@ def update(self, name: str, updated_agent: Agent):
565565
updated_skills.add(skill.name)
566566

567567
existing_agent = self.api.agent(self.project.name, name)
568-
existing_skills = set([s['name'] for s in existing_agent['skills']])
568+
569+
existing_skills = set([s['name'] for s in existing_agent.get('skills', [])])
569570
skills_to_add = updated_skills.difference(existing_skills)
570571
skills_to_remove = existing_skills.difference(updated_skills)
572+
571573
updated_model_name = None
572574
updated_provider = updated_agent.provider
573575
updated_model = None

tests/test_sdk.py

Lines changed: 123 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,12 +1337,18 @@ def test_list(self, mock_get):
13371337
'id': 1,
13381338
'name': 'test_agent',
13391339
'project_id': 1,
1340-
'model_name': 'test_model',
1341-
'skills': [],
1340+
'model': {
1341+
'model_name': 'gpt-3.5-turbo',
1342+
'provider': 'openai',
1343+
'api_key': 'sk-...',
1344+
},
1345+
'data': {
1346+
'tables': ['test_database.test_table'],
1347+
'knowledge_bases': ['test_kb'],
1348+
},
13421349
'params': {},
13431350
'created_at': created_at,
13441351
'updated_at': updated_at,
1345-
'provider': 'mindsdb'
13461352
}
13471353
])
13481354
all_agents = server.agents.list()
@@ -1352,12 +1358,17 @@ def test_list(self, mock_get):
13521358
assert len(all_agents) == 1
13531359
expected_agent = Agent(
13541360
'test_agent',
1355-
'test_model',
1356-
[],
1357-
{},
13581361
created_at,
13591362
updated_at,
1360-
'mindsdb'
1363+
model={
1364+
'model_name': 'gpt-3.5-turbo',
1365+
'provider': 'openai',
1366+
'api_key': 'sk-...',
1367+
},
1368+
data={
1369+
'tables': ['test_database.test_table'],
1370+
'knowledge_bases': ['test_kb'],
1371+
},
13611372
)
13621373
assert all_agents[0] == expected_agent
13631374

@@ -1371,118 +1382,136 @@ def test_get(self, mock_get):
13711382
'id': 1,
13721383
'name': 'test_agent',
13731384
'project_id': 1,
1374-
'model_name': 'test_model',
1375-
'skills': [],
1385+
'model': {
1386+
'model_name': 'gpt-3.5-turbo',
1387+
'provider': 'openai',
1388+
'api_key': 'sk-...',
1389+
},
1390+
'data': {
1391+
'tables': ['test_database.test_table'],
1392+
'knowledge_bases': ['test_kb'],
1393+
},
13761394
'params': {},
13771395
'created_at': created_at,
13781396
'updated_at': updated_at,
1379-
'provider': 'mindsdb'
13801397
}
13811398
)
13821399
agent = server.agents.get('test_agent')
13831400
# Check API call.
13841401
assert mock_get.call_args[0][0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/agents/test_agent'
13851402
expected_agent = Agent(
13861403
'test_agent',
1387-
'test_model',
1388-
[],
1389-
{},
13901404
created_at,
13911405
updated_at,
1392-
'mindsdb'
1406+
model={
1407+
'model_name': 'gpt-3.5-turbo',
1408+
'provider': 'openai',
1409+
'api_key': 'sk-...',
1410+
},
1411+
data={
1412+
'tables': ['test_database.test_table'],
1413+
'knowledge_bases': ['test_kb'],
1414+
},
13931415
)
13941416
assert agent == expected_agent
13951417

13961418
@patch('requests.Session.post')
1397-
@patch('requests.Session.get')
1398-
def test_create(self, mock_get, mock_post):
1419+
def test_create(self, mock_post):
13991420
created_at = dt.datetime(2000, 3, 1, 9, 30)
14001421
updated_at = dt.datetime(2001, 3, 1, 9, 30)
14011422
data = {
14021423
'id': 1,
14031424
'name': 'test_agent',
14041425
'project_id': 1,
1405-
'model_name': 'test_model',
1406-
'skills': [{
1407-
'id': 0,
1408-
'name': 'test_skill',
1409-
'project_id': 1,
1410-
'type': 'sql',
1411-
'params': {'tables': ['test_table'], 'database': 'test_database', 'description': 'test_description'},
1412-
}],
1413-
'params': {'k1': 'v1'},
1426+
'model': {
1427+
'model_name': 'gpt-3.5-turbo',
1428+
'provider': 'openai',
1429+
'api_key': 'sk-...',
1430+
},
1431+
'data': {
1432+
'tables': ['test_database.test_table'],
1433+
'knowledge_bases': ['test_kb'],
1434+
},
14141435
'created_at': created_at,
14151436
'updated_at': updated_at,
1416-
'provider': 'mindsdb',
14171437
}
14181438
responses_mock(mock_post, [
1419-
# ML Engine get (SQL post for SHOW ML_ENGINES)
14201439
data
14211440
])
1422-
responses_mock(mock_get, [
1423-
# Skill get.
1424-
{'name': 'test_skill', 'type': 'sql', 'params': {'tables': ['test_table'], 'database': 'test_database', 'description': 'test_description'}},
1425-
])
14261441

14271442
# Create the agent.
14281443
server = mindsdb_sdk.connect()
14291444
new_agent = server.agents.create(
14301445
name='test_agent',
1431-
model=Model(None, {'name':'m1'}),
1432-
skills=['test_skill'],
1433-
params={'k1': 'v1'}
1446+
model={
1447+
'model_name': 'gpt-3.5-turbo',
1448+
'provider': 'openai',
1449+
'api_key': 'sk-...',
1450+
},
1451+
data={
1452+
'tables': ['test_database.test_table'],
1453+
'knowledge_bases': ['test_kb'],
1454+
}
14341455
)
14351456
# Check API call.
14361457
assert len(mock_post.call_args_list) == 1
14371458
assert mock_post.call_args_list[-1][0][0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/agents'
14381459
assert mock_post.call_args_list[-1][1]['json'] == {
14391460
'agent': {
14401461
'name': 'test_agent',
1441-
'model_name': 'm1',
1442-
'skills': ['test_skill'],
1443-
'params': {
1444-
'k1': 'v1',
1445-
'prompt_template': 'Answer the user"s question in a helpful way: {{question}}'
1462+
'model_name': None,
1463+
'provider': None,
1464+
'skills': [],
1465+
'model': {
1466+
'model_name': 'gpt-3.5-turbo',
1467+
'provider': 'openai',
1468+
'api_key': 'sk-...',
14461469
},
1447-
'provider': 'mindsdb'
1470+
'data': {
1471+
'tables': ['test_database.test_table'],
1472+
'knowledge_bases': ['test_kb'],
1473+
},
1474+
'prompt_template': None,
1475+
'params': {}
14481476
}
14491477
}
1450-
expected_skill = SQLSkill('test_skill', ['test_table'], 'test_database', 'test_description')
14511478
expected_agent = Agent(
14521479
'test_agent',
1453-
'test_model',
1454-
[expected_skill],
1455-
{'k1': 'v1'},
1456-
created_at,
1457-
updated_at,
1458-
'mindsdb'
1480+
created_at=created_at,
1481+
updated_at=updated_at,
1482+
model={
1483+
'model_name': 'gpt-3.5-turbo',
1484+
'provider': 'openai',
1485+
'api_key': 'sk-...',
1486+
},
1487+
data={
1488+
'tables': ['test_database.test_table'],
1489+
'knowledge_bases': ['test_kb'],
1490+
}
14591491
)
14601492

14611493
assert new_agent == expected_agent
14621494

14631495
@patch('requests.Session.get')
14641496
@patch('requests.Session.put')
1465-
# Mock creating new skills.
1466-
@patch('requests.Session.post')
1467-
def test_update(self, mock_get, mock_put, _):
1497+
def test_update(self, mock_put, mock_get):
14681498
created_at = dt.datetime(2000, 3, 1, 9, 30)
14691499
updated_at = dt.datetime(2001, 3, 1, 9, 30)
14701500
data = {
14711501
'id': 1,
14721502
'name': 'test_agent',
14731503
'project_id': 1,
1474-
'model_name': 'updated_model',
1475-
'skills': [{
1476-
'id': 1,
1477-
'name': 'updated_skill',
1478-
'project_id': 1,
1479-
'type': 'sql',
1480-
'params': {'tables': ['updated_table'], 'database': 'updated_database', 'description': 'test_description'},
1481-
}],
1482-
'params': {'k2': 'v2'},
1504+
'model': {
1505+
'model_name': 'gpt-3.5-turbo',
1506+
'provider': 'openai',
1507+
'api_key': 'sk-...',
1508+
},
1509+
'data': {
1510+
'tables': ['test_database.test_table'],
1511+
'knowledge_bases': ['test_kb'],
1512+
},
14831513
'created_at': created_at,
14841514
'updated_at': updated_at,
1485-
'provider': 'mindsdb',
14861515
}
14871516
response_mock(mock_put, data)
14881517

@@ -1491,21 +1520,33 @@ def test_update(self, mock_get, mock_put, _):
14911520
'id': 1,
14921521
'name': 'test_agent',
14931522
'project_id': 1,
1494-
'model_name': 'test_model',
1495-
'skills': [],
1496-
'params': {'k1': 'v1'},
1497-
'provider': 'mindsdb',
1523+
'model': {
1524+
'model_name': 'gpt-3.5-turbo',
1525+
'provider': 'openai',
1526+
'api_key': 'sk-...',
1527+
},
1528+
'data': {
1529+
'tables': ['test_database.test_table'],
1530+
'knowledge_bases': ['test_kb'],
1531+
},
1532+
'created_at': created_at,
1533+
'updated_at': updated_at,
14981534
})
14991535

15001536
server = mindsdb_sdk.connect()
15011537
expected_agent = Agent(
15021538
'test_agent',
1503-
'updated_model',
1504-
[SQLSkill('updated_skill', ['updated_table'], 'updated_database', 'test_description')],
1505-
{'k2': 'v2'},
15061539
created_at,
15071540
updated_at,
1508-
'mindsdb'
1541+
model={
1542+
'model_name': 'gpt-3.5-turbo',
1543+
'provider': 'openai',
1544+
'api_key': 'sk-...',
1545+
},
1546+
data={
1547+
'tables': ['test_database.test_table'],
1548+
'knowledge_bases': ['test_kb', 'test_kb2'],
1549+
},
15091550
)
15101551

15111552
updated_agent = server.agents.update('test_agent', expected_agent)
@@ -1514,11 +1555,21 @@ def test_update(self, mock_get, mock_put, _):
15141555
assert mock_put.call_args[1]['json'] == {
15151556
'agent': {
15161557
'name': 'test_agent',
1517-
'model_name': 'updated_model',
1518-
'skills_to_add': ['updated_skill'],
1558+
'model_name': None,
1559+
'provider': None,
1560+
'skills_to_add': [],
15191561
'skills_to_remove': [],
1520-
'params': {'k2': 'v2'},
1521-
'provider': 'mindsdb'
1562+
'data': {
1563+
'tables': ['test_database.test_table'],
1564+
'knowledge_bases': ['test_kb', 'test_kb2'],
1565+
},
1566+
'model': {
1567+
'model_name': 'gpt-3.5-turbo',
1568+
'provider': 'openai',
1569+
'api_key': 'sk-...',
1570+
},
1571+
'prompt_template': None,
1572+
'params': {},
15221573
}
15231574
}
15241575

0 commit comments

Comments
 (0)