@@ -82,23 +82,21 @@ def __init__(
8282 name : str ,
8383 created_at : datetime .datetime ,
8484 updated_at : datetime .datetime ,
85- model_name : str = None ,
85+ model : Union [ Model , str , dict ] = None ,
8686 skills : List [Skill ] = [],
8787 provider : str = None ,
8888 data : dict = {},
89- model : dict = {},
9089 prompt_template : str = None ,
9190 params : dict = {},
9291 collection : CollectionBase = None
9392 ):
9493 self .name = name
9594 self .created_at = created_at
9695 self .updated_at = updated_at
97- self .model_name = model_name
96+ self .model = model
9897 self .skills = skills
9998 self .provider = provider
10099 self .data = data
101- self .model = model
102100 self .prompt_template = prompt_template
103101 self .params = params
104102 self .collection = collection
@@ -207,15 +205,16 @@ def from_json(cls, json: dict, collection: CollectionBase):
207205 if json .get ('skills' ):
208206 skills = [Skill .from_json (skill ) for skill in json ['skills' ]]
209207
208+ model = json .get ('model' ) or json .get ('model_name' )
209+
210210 return cls (
211211 json ['name' ],
212212 json ['created_at' ],
213213 json ['updated_at' ],
214- json . get ( 'model_name' ) ,
214+ model ,
215215 skills ,
216216 json .get ('provider' ),
217217 json .get ('data' , {}),
218- json .get ('model' , {}),
219218 json .get ('prompt_template' ),
220219 json .get ('params' , {}),
221220 collection
@@ -488,11 +487,10 @@ def add_database(self, name: str, database: str, tables: List[str], description:
488487 def create (
489488 self ,
490489 name : str ,
491- model_name : Union [Model , str ] = None ,
490+ model : Union [Model , str , dict ] = None ,
492491 provider : str = None ,
493492 skills : List [Union [Skill , str ]] = None ,
494493 data : dict = None ,
495- model : dict = None ,
496494 prompt_template : str = None ,
497495 params : dict = None ,
498496 ** kwargs
@@ -501,11 +499,10 @@ def create(
501499 Create new agent and return it
502500
503501 :param name: Name of the agent to be created
504- :param model_name: MindsDB model to be used by the agent
502+ :param model: Model to be used by the agent. This can be a Model object, a string with model name, or a dictionary with model parameters.
505503 :param skills: List of skills to be used by the agent. Currently only 'sql' is supported.
506504 :param provider: Provider of the model, e.g. 'mindsdb', 'openai', etc.
507505 :param data: Data to be used by the agent. This is usually a dictionary with 'tables' and/or 'knowledge_base' keys.
508- :param model: Model parameters to be used by the agent. This is usually a dictionary
509506 :param params: Parameters for the agent
510507
511508 :return: created agent object
@@ -526,10 +523,15 @@ def create(
526523 if params is None :
527524 params = {}
528525 params .update (kwargs )
529-
526+
527+ model_name = None
530528 if isinstance (model_name , Model ):
531529 model_name = model_name .name
532530 provider = 'mindsdb'
531+ model = None
532+ elif isinstance (model , str ):
533+ model_name = model
534+ model = None
533535
534536 agent = self .api .create_agent (
535537 self .project .name ,
@@ -574,20 +576,31 @@ def update(self, name: str, updated_agent: Agent):
574576 existing_skills = set ([s ['name' ] for s in existing_agent ['skills' ]])
575577 skills_to_add = updated_skills .difference (existing_skills )
576578 skills_to_remove = existing_skills .difference (updated_skills )
577- data = self .api .update_agent (
579+ updated_model_name = None
580+ updated_provider = updated_agent .provider
581+ updated_model = None
582+ if isinstance (updated_agent .model , Model ):
583+ updated_model_name = updated_agent .model .name
584+ updated_provider = 'mindsdb'
585+ elif isinstance (updated_agent .model , str ):
586+ updated_model_name = updated_agent .model
587+ elif isinstance (updated_agent .model , dict ):
588+ updated_model = updated_agent .model
589+
590+ agent = self .api .update_agent (
578591 self .project .name ,
579592 name ,
580593 updated_agent .name ,
581- updated_agent . provider ,
582- updated_agent . model_name ,
594+ updated_provider ,
595+ updated_model_name ,
583596 list (skills_to_add ),
584597 list (skills_to_remove ),
585598 updated_agent .data ,
586- updated_agent . model ,
599+ updated_model ,
587600 updated_agent .prompt_template ,
588601 updated_agent .params
589602 )
590- return Agent .from_json (data , self )
603+ return Agent .from_json (agent , self )
591604
592605 def drop (self , name : str ):
593606 """
0 commit comments