Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 73 additions & 32 deletions mindsdb_sdk/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,25 @@ class Agent:
def __init__(
self,
name: str,
model_name: str,
skills: List[Skill],
params: dict,
created_at: datetime.datetime,
updated_at: datetime.datetime,
model: Union[Model, str, dict] = None,
skills: List[Skill] = [],
provider: str = None,
data: dict = {},
prompt_template: str = None,
params: dict = {},
collection: CollectionBase = None
):
self.name = name
self.model_name = model_name
self.provider = provider
self.skills = skills
self.params = params
self.created_at = created_at
self.updated_at = updated_at
self.model = model
self.skills = skills
self.provider = provider
self.data = data
self.prompt_template = prompt_template
self.params = params
self.collection = collection

def completion(self, messages: List[dict]) -> AgentCompletion:
Expand Down Expand Up @@ -197,14 +201,22 @@ def __eq__(self, other):

@classmethod
def from_json(cls, json: dict, collection: CollectionBase):
skills = []
if json.get('skills'):
skills = [Skill.from_json(skill) for skill in json['skills']]

model = json.get('model') or json.get('model_name')

return cls(
json['name'],
json['model_name'],
[Skill.from_json(skill) for skill in json['skills']],
json['params'],
json['created_at'],
json['updated_at'],
json['provider'],
model,
skills,
json.get('provider'),
json.get('data', {}),
json.get('prompt_template'),
json.get('params', {}),
collection
)

Expand Down Expand Up @@ -473,19 +485,24 @@ def add_database(self, name: str, database: str, tables: List[str], description:
self.update(agent.name, agent)

def create(
self,
name: str,
model: Union[Model, dict, str] = None,
provider: str = None,
skills: List[Union[Skill, str]] = None,
params: dict = None,
**kwargs) -> Agent:
self,
name: str,
model: Union[Model, str, dict] = None,
provider: str = None,
skills: List[Union[Skill, str]] = None,
data: dict = None,
prompt_template: str = None,
params: dict = None,
**kwargs
) -> Agent:
"""
Create new agent and return it

:param name: Name of the agent to be created
:param model: Model to be used by the agent
:param model: Model to be used by the agent. This can be a Model object, a string with model name, or a dictionary with model parameters.
:param skills: List of skills to be used by the agent. Currently only 'sql' is supported.
:param provider: Provider of the model, e.g. 'mindsdb', 'openai', etc.
:param data: Data to be used by the agent. This is usually a dictionary with 'tables' and/or 'knowledge_base' keys.
:param params: Parameters for the agent

:return: created agent object
Expand All @@ -507,17 +524,27 @@ def create(
params = {}
params.update(kwargs)

if 'prompt_template' not in params:
params['prompt_template'] = _DEFAULT_LLM_PROMPT

if model is None:
model = _DEFAULT_LLM_MODEL
elif isinstance(model, Model):
model = model.name
model_name = None
if isinstance(model_name, Model):
model_name = model_name.name
provider = 'mindsdb'
model = None
elif isinstance(model, str):
model_name = model
model = None

data = self.api.create_agent(self.project.name, name, model, provider, skill_names, params)
return Agent.from_json(data, self)
agent = self.api.create_agent(
self.project.name,
name,
model_name,
provider,
skill_names,
data,
model,
prompt_template,
params
)
return Agent.from_json(agent, self)

def update(self, name: str, updated_agent: Agent):
"""
Expand Down Expand Up @@ -549,17 +576,31 @@ def update(self, name: str, updated_agent: Agent):
existing_skills = set([s['name'] for s in existing_agent['skills']])
skills_to_add = updated_skills.difference(existing_skills)
skills_to_remove = existing_skills.difference(updated_skills)
data = self.api.update_agent(
updated_model_name = None
updated_provider = updated_agent.provider
updated_model = None
if isinstance(updated_agent.model, Model):
updated_model_name = updated_agent.model.name
updated_provider = 'mindsdb'
elif isinstance(updated_agent.model, str):
updated_model_name = updated_agent.model
elif isinstance(updated_agent.model, dict):
updated_model = updated_agent.model

agent = self.api.update_agent(
self.project.name,
name,
updated_agent.name,
updated_agent.provider,
updated_agent.model_name,
updated_provider,
updated_model_name,
list(skills_to_add),
list(skills_to_remove),
updated_agent.data,
updated_model,
updated_agent.prompt_template,
updated_agent.params
)
return Agent.from_json(data, self)
return Agent.from_json(agent, self)

def drop(self, name: str):
"""
Expand Down
46 changes: 33 additions & 13 deletions mindsdb_sdk/connectors/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,16 +306,30 @@ def agent_completion_stream_v2(self, project: str, name: str, messages: List[dic
yield e

@_try_relogin
def create_agent(self, project: str, name: str, model: str = None, provider: str = None, skills: List[str] = None, params: dict = None):
def create_agent(
self,
project: str,
name: str,
model_name: str = None,
provider: str = None,
skills: List[str] = None,
data: dict = None,
model: dict = None,
prompt_template: str = None,
params: dict = None
):
url = self.url + f'/api/projects/{project}/agents'
r = self.session.post(
url,
json={
'agent': {
'name': name,
'model_name': model,
'model_name': model_name,
'provider': provider,
'skills': skills,
'data': data,
'model': model,
'prompt_template': prompt_template,
'params': params
}
}
Expand All @@ -325,26 +339,32 @@ def create_agent(self, project: str, name: str, model: str = None, provider: str

@_try_relogin
def update_agent(
self,
project: str,
name: str,
updated_name: str,
updated_provider: str,
updated_model: str,
skills_to_add: List[str],
skills_to_remove: List[str],
updated_params: dict
):
self,
project: str,
name: str,
updated_name: str,
updated_provider: str,
updated_model_name: str,
skills_to_add: List[str],
skills_to_remove: List[str],
updated_data: dict,
updated_model: dict,
updated_prompt_template: str,
updated_params: dict
):
url = self.url + f'/api/projects/{project}/agents/{name}'
r = self.session.put(
url,
json={
'agent': {
'name': updated_name,
'model_name': updated_model,
'model_name': updated_model_name,
'provider': updated_provider,
'skills_to_add': skills_to_add,
'skills_to_remove': skills_to_remove,
'data': updated_data,
'model': updated_model,
'prompt_template': updated_prompt_template,
'params': updated_params
}
}
Expand Down
Loading