Skip to content

Commit 203c86c

Browse files
updated the attributes of the CreateKnowledgeBase class
1 parent 26cece0 commit 203c86c

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

mindsdb_sql_parser/ast/mindsdb/knowledge_base.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
13
from mindsdb_sql_parser.ast.base import ASTNode
24
from mindsdb_sql_parser.utils import indent
35

@@ -9,7 +11,8 @@ class CreateKnowledgeBase(ASTNode):
911
def __init__(
1012
self,
1113
name,
12-
model=None,
14+
embedding_model=None,
15+
reranking_model=None,
1316
storage=None,
1417
from_select=None,
1518
params=None,
@@ -20,30 +23,37 @@ def __init__(
2023
"""
2124
Args:
2225
name: Identifier -- name of the knowledge base
23-
model: Identifier -- name of the model to use
26+
embedding_model: dict -- name of the embedding model to use
27+
reranking_model: dict -- name of the reranking model to use
2428
storage: Identifier -- name of the storage to use
2529
from_select: SelectStatement -- select statement to use as the source of the knowledge base
2630
params: dict -- additional parameters to pass to the knowledge base. E.g., chunking strategy, etc.
2731
if_not_exists: bool -- if True, do not raise an error if the knowledge base already exists
2832
"""
2933
super().__init__(*args, **kwargs)
3034
self.name = name
31-
self.model = model
35+
self.embedding_model = embedding_model
36+
self.reranking_model = reranking_model
3237
self.storage = storage
3338
self.params = params
3439
self.if_not_exists = if_not_exists
3540
self.from_query = from_select
3641

3742
def to_tree(self, *args, level=0, **kwargs):
3843
ind = indent(level)
39-
storage_str = f"{ind} storage={self.storage.to_string()},\n" if self.storage else ""
40-
model_str = f"{ind} model={self.model.to_string()},\n" if self.model else ""
44+
storage_str = f"{ind}storage={self.storage.to_string()}," if self.storage else ""
45+
embedding_model_str = f"{ind}embedding_model={str(self.embedding_model)}," if self.embedding_model else ""
46+
reranking_model_str = f"{ind}reranking_model={str(self.reranking_model)}," if self.reranking_model else ""
47+
4148
out_str = f"""
4249
{ind}CreateKnowledgeBase(
4350
{ind} if_not_exists={self.if_not_exists},
4451
{ind} name={self.name.to_string()},
4552
{ind} from_query={self.from_query.to_tree(level=level + 1) if self.from_query else None},
46-
{model_str}{storage_str}{ind} params={self.params}
53+
{ind} {embedding_model_str}
54+
{ind} {reranking_model_str}
55+
{ind} {storage_str}
56+
{ind} params={self.params}
4757
{ind})
4858
"""
4959
return out_str
@@ -56,8 +66,11 @@ def get_string(self, *args, **kwargs):
5666
using_ar = []
5767
if self.storage:
5868
using_ar.append(f" STORAGE={self.storage.to_string()}")
59-
if self.model:
60-
using_ar.append(f" MODEL={self.model.to_string()}")
69+
embedding_model_str = ""
70+
if self.embedding_model_str:
71+
using_ar.append(f" EMBEDDING_MODEL={json.dumps(self.embedding_model)}")
72+
if self.reranking_model_str:
73+
using_ar.append(f" RERANKING_MODEL={json.dumps(self.reranking_model)}")
6174

6275
params = self.params.copy()
6376
if params:

0 commit comments

Comments
 (0)