Skip to content

Commit 70d168f

Browse files
combined the model_name and model parameters
1 parent a935f81 commit 70d168f

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

mindsdb_sdk/agents.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,23 +82,21 @@ def __init__(
8282
name: str,
8383
created_at: datetime.datetime,
8484
updated_at: datetime.datetime,
85-
model_name: str = None,
85+
model: Union[Model, str, dict] = None,
8686
skills: List[Skill] = [],
8787
provider: str = None,
8888
data: dict = {},
89-
model: dict = {},
9089
prompt_template: str = None,
9190
params: dict = {},
9291
collection: CollectionBase = None
9392
):
9493
self.name = name
9594
self.created_at = created_at
9695
self.updated_at = updated_at
97-
self.model_name = model_name
96+
self.model = model
9897
self.skills = skills
9998
self.provider = provider
10099
self.data = data
101-
self.model = model
102100
self.prompt_template = prompt_template
103101
self.params = params
104102
self.collection = collection
@@ -207,15 +205,16 @@ def from_json(cls, json: dict, collection: CollectionBase):
207205
if json.get('skills'):
208206
skills = [Skill.from_json(skill) for skill in json['skills']]
209207

208+
model = json.get('model') or json.get('model_name')
209+
210210
return cls(
211211
json['name'],
212212
json['created_at'],
213213
json['updated_at'],
214-
json.get('model_name'),
214+
model,
215215
skills,
216216
json.get('provider'),
217217
json.get('data', {}),
218-
json.get('model', {}),
219218
json.get('prompt_template'),
220219
json.get('params', {}),
221220
collection
@@ -488,11 +487,10 @@ def add_database(self, name: str, database: str, tables: List[str], description:
488487
def create(
489488
self,
490489
name: str,
491-
model_name: Union[Model, str] = None,
490+
model: Union[Model, str, dict] = None,
492491
provider: str = None,
493492
skills: List[Union[Skill, str]] = None,
494493
data: dict = None,
495-
model: dict = None,
496494
prompt_template: str = None,
497495
params: dict = None,
498496
**kwargs
@@ -501,11 +499,10 @@ def create(
501499
Create new agent and return it
502500
503501
:param name: Name of the agent to be created
504-
:param model_name: MindsDB model to be used by the agent
502+
:param model: Model to be used by the agent. This can be a Model object, a string with model name, or a dictionary with model parameters.
505503
:param skills: List of skills to be used by the agent. Currently only 'sql' is supported.
506504
:param provider: Provider of the model, e.g. 'mindsdb', 'openai', etc.
507505
:param data: Data to be used by the agent. This is usually a dictionary with 'tables' and/or 'knowledge_base' keys.
508-
:param model: Model parameters to be used by the agent. This is usually a dictionary
509506
:param params: Parameters for the agent
510507
511508
:return: created agent object
@@ -526,10 +523,15 @@ def create(
526523
if params is None:
527524
params = {}
528525
params.update(kwargs)
529-
526+
527+
model_name = None
530528
if isinstance(model_name, Model):
531529
model_name = model_name.name
532530
provider = 'mindsdb'
531+
model = None
532+
elif isinstance(model, str):
533+
model_name = model
534+
model = None
533535

534536
agent = self.api.create_agent(
535537
self.project.name,
@@ -574,20 +576,31 @@ def update(self, name: str, updated_agent: Agent):
574576
existing_skills = set([s['name'] for s in existing_agent['skills']])
575577
skills_to_add = updated_skills.difference(existing_skills)
576578
skills_to_remove = existing_skills.difference(updated_skills)
577-
data = self.api.update_agent(
579+
updated_model_name = None
580+
updated_provider = updated_agent.provider
581+
updated_model = None
582+
if isinstance(updated_agent.model, Model):
583+
updated_model_name = updated_agent.model.name
584+
updated_provider = 'mindsdb'
585+
elif isinstance(updated_agent.model, str):
586+
updated_model_name = updated_agent.model
587+
elif isinstance(updated_agent.model, dict):
588+
updated_model = updated_agent.model
589+
590+
agent = self.api.update_agent(
578591
self.project.name,
579592
name,
580593
updated_agent.name,
581-
updated_agent.provider,
582-
updated_agent.model_name,
594+
updated_provider,
595+
updated_model_name,
583596
list(skills_to_add),
584597
list(skills_to_remove),
585598
updated_agent.data,
586-
updated_agent.model,
599+
updated_model,
587600
updated_agent.prompt_template,
588601
updated_agent.params
589602
)
590-
return Agent.from_json(data, self)
603+
return Agent.from_json(agent, self)
591604

592605
def drop(self, name: str):
593606
"""

0 commit comments

Comments
 (0)