Skip to content

Commit 5139504

Browse files
committed
add dataframe to agent
1 parent 633527f commit 5139504

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

mindsdb_sdk/agents.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from uuid import uuid4
55
import datetime
66
import json
7+
import pandas as pd
78

89
from mindsdb_sdk.knowledge_bases import KnowledgeBase
910
from mindsdb_sdk.models import Model
@@ -155,6 +156,41 @@ def add_webpage(
155156
"""
156157
self.collection.add_webpage(self.name, url, description, knowledge_base=knowledge_base, crawl_depth=crawl_depth, filters=filters)
157158

159+
def add_dataframe(
160+
self,
161+
df: pd.DataFrame,
162+
description: str,
163+
knowledge_base: str = None
164+
):
165+
"""
166+
Add a list of webpages to the agent for retrieval.
167+
168+
:param df: dataframe to be added.
169+
:param description: Description of the webpages. Used by agent to know when to do retrieval.
170+
:param knowledge_base: Name of an existing knowledge base to be used. Will create a default knowledge base if not given.
171+
"""
172+
if df is None or df.empty:
173+
return
174+
175+
if knowledge_base is not None:
176+
kb = self.collection.knowledge_bases.get(knowledge_base)
177+
else:
178+
kb_name = f'{self.name.lower()}_df_{uuid4().hex}_kb'
179+
kb = self.collection._create_default_knowledge_base(self, kb_name)
180+
181+
# Insert crawled webpage.
182+
kb.insert(df)
183+
184+
# Make sure skill name is unique.
185+
skill_name = f'df_retrieval_skill_{uuid4().hex}'
186+
retrieval_params = {
187+
'source': kb.name,
188+
'description': description,
189+
}
190+
dataframe_retrieval_skill = self.collection.skills.create(skill_name, 'retrieval', retrieval_params)
191+
self.skills.append(dataframe_retrieval_skill)
192+
self.collection.update(self.name, self)
193+
158194
def add_database(self, database: str, tables: List[str], description: str):
159195
"""
160196
Add a database to the agent for retrieval.

tests/test_sdk.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,6 +1664,82 @@ def test_add_webpage(self, mock_post, mock_put, mock_get):
16641664
}
16651665
assert agent_update_json == expected_agent_json
16661666

1667+
@patch('requests.Session.get')
1668+
@patch('requests.Session.put')
1669+
@patch('requests.Session.post')
1670+
def test_add_dataframe(self, mock_post, mock_put, mock_get):
1671+
server = mindsdb_sdk.connect()
1672+
responses_mock(mock_get, [
1673+
# Existing agent get.
1674+
{
1675+
'name': 'test_agent',
1676+
'model_name': 'test_model',
1677+
'skills': [],
1678+
'params': {},
1679+
'created_at': None,
1680+
'updated_at': None,
1681+
'provider': 'mindsdb'
1682+
},
1683+
# get KB
1684+
{
1685+
'id': 1,
1686+
'name': 'my_kb',
1687+
'project_id': 1,
1688+
'embedding_model': 'openai_emb',
1689+
'vector_database': 'pvec',
1690+
'vector_database_table': 'tbl1',
1691+
'updated_at': '2024-10-04 10:55:25.350799',
1692+
'created_at': '2024-10-04 10:55:25.350790',
1693+
'params': {}
1694+
},
1695+
# Skills get in Agent update to check if it exists.
1696+
{'name':'new_skill', 'type':'retrieval', 'params':{'source':'test_agent_docs_mdb_ai_kb'}},
1697+
# Existing agent get in Agent update.
1698+
{
1699+
'name':'test_agent',
1700+
'model_name':'test_model',
1701+
'skills':[],
1702+
'params':{},
1703+
'created_at':None,
1704+
'updated_at':None,
1705+
'provider':'mindsdb' # Added provider field
1706+
},
1707+
])
1708+
responses_mock(mock_post, [
1709+
# Skill creation.
1710+
{'name':'new_skill', 'type':'retrieval', 'params':{'source':'test_agent_docs_mdb_ai_kb'}}
1711+
])
1712+
responses_mock(mock_put, [
1713+
# KB update.
1714+
{'name':'test_agent_docs_mdb_ai_kb'},
1715+
# Agent update with new skill.
1716+
{
1717+
'name':'test_agent',
1718+
'model_name':'test_model',
1719+
'skills':[{'name':'new_skill', 'type':'retrieval', 'params':{'source':'test_agent_docs_mdb_ai_kb'}}],
1720+
'params':{},
1721+
'created_at':None,
1722+
'updated_at':None,
1723+
'provider':'mindsdb' # Added provider field
1724+
},
1725+
])
1726+
server.agents.test_agent.add_dataframe(pd.DataFrame([{'content': 'doc'}]), 'Documentation for MindsDB', 'existing_kb')
1727+
1728+
# Check Agent was updated with a new skill.
1729+
agent_update_json = mock_put.call_args[-1]['json']
1730+
expected_agent_json = {
1731+
'agent':{
1732+
'name':'test_agent',
1733+
'model_name':'test_model',
1734+
# Skill name is a generated UUID.
1735+
'skills_to_add':[agent_update_json['agent']['skills_to_add'][0]],
1736+
'skills_to_remove':[],
1737+
'params':{},
1738+
'provider': 'mindsdb'
1739+
}
1740+
}
1741+
assert agent_update_json == expected_agent_json
1742+
16671743
@patch('requests.Session.get')
16681744
@patch('requests.Session.put')
16691745
@patch('requests.Session.post')

0 commit comments

Comments
 (0)