1+ import json
2+
13from mindsdb_sql_parser .ast .base import ASTNode
24from mindsdb_sql_parser .utils import indent
35
@@ -9,7 +11,8 @@ class CreateKnowledgeBase(ASTNode):
911 def __init__ (
1012 self ,
1113 name ,
12- model = None ,
14+ embedding_model = None ,
15+ reranking_model = None ,
1316 storage = None ,
1417 from_select = None ,
1518 params = None ,
@@ -20,30 +23,37 @@ def __init__(
2023 """
2124 Args:
2225 name: Identifier -- name of the knowledge base
23- model: Identifier -- name of the model to use
26+ embedding_model: dict -- name of the embedding model to use
27+ reranking_model: dict -- name of the reranking model to use
2428 storage: Identifier -- name of the storage to use
2529 from_select: SelectStatement -- select statement to use as the source of the knowledge base
2630 params: dict -- additional parameters to pass to the knowledge base. E.g., chunking strategy, etc.
2731 if_not_exists: bool -- if True, do not raise an error if the knowledge base already exists
2832 """
2933 super ().__init__ (* args , ** kwargs )
3034 self .name = name
31- self .model = model
35+ self .embedding_model = embedding_model
36+ self .reranking_model = reranking_model
3237 self .storage = storage
3338 self .params = params
3439 self .if_not_exists = if_not_exists
3540 self .from_query = from_select
3641
3742 def to_tree (self , * args , level = 0 , ** kwargs ):
3843 ind = indent (level )
39- storage_str = f"{ ind } storage={ self .storage .to_string ()} ,\n " if self .storage else ""
40- model_str = f"{ ind } model={ self .model .to_string ()} ,\n " if self .model else ""
44+ storage_str = f"{ ind } storage={ self .storage .to_string ()} ," if self .storage else ""
45+ embedding_model_str = f"{ ind } embedding_model={ str (self .embedding_model )} ," if self .embedding_model else ""
46+ reranking_model_str = f"{ ind } reranking_model={ str (self .reranking_model )} ," if self .reranking_model else ""
47+
4148 out_str = f"""
4249 { ind } CreateKnowledgeBase(
4350 { ind } if_not_exists={ self .if_not_exists } ,
4451 { ind } name={ self .name .to_string ()} ,
4552 { ind } from_query={ self .from_query .to_tree (level = level + 1 ) if self .from_query else None } ,
46- { model_str } { storage_str } { ind } params={ self .params }
53+ { ind } { embedding_model_str }
54+ { ind } { reranking_model_str }
55+ { ind } { storage_str }
56+ { ind } params={ self .params }
4757 { ind } )
4858 """
4959 return out_str
@@ -56,8 +66,11 @@ def get_string(self, *args, **kwargs):
5666 using_ar = []
5767 if self .storage :
5868 using_ar .append (f" STORAGE={ self .storage .to_string ()} " )
59- if self .model :
60- using_ar .append (f" MODEL={ self .model .to_string ()} " )
69+ embedding_model_str = ""
70+ if self .embedding_model_str :
71+ using_ar .append (f" EMBEDDING_MODEL={ json .dumps (self .embedding_model )} " )
72+ if self .reranking_model_str :
73+ using_ar .append (f" RERANKING_MODEL={ json .dumps (self .reranking_model )} " )
6174
6275 params = self .params .copy ()
6376 if params :
0 commit comments