Skip to content

Commit 00b9813

Browse files
committed
feat: Advanced Application support to use SQL Examples
1 parent 9ed3427 commit 00b9813

File tree

9 files changed

+251
-96
lines changed

9 files changed

+251
-96
lines changed

backend/apps/chat/task/llm.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,12 @@ def select_datasource(self, _session: Session):
505505

506506
self.chat_question.terminologies = get_terminology_template(_session, self.chat_question.question, oid,
507507
ds_id)
508-
self.chat_question.data_training = get_training_template(_session, self.chat_question.question, ds_id,
509-
oid)
508+
if self.current_assistant and self.current_assistant.type == 1:
509+
self.chat_question.data_training = get_training_template(_session, self.chat_question.question,
510+
oid, None, self.current_assistant.id)
511+
else:
512+
self.chat_question.data_training = get_training_template(_session, self.chat_question.question,
513+
oid, ds_id)
510514
if SQLBotLicenseUtil.valid():
511515
self.chat_question.custom_prompt = find_custom_prompts(_session, CustomPromptTypeEnum.GENERATE_SQL,
512516
oid, ds_id)
@@ -902,8 +906,12 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
902906
ds_id = self.ds.id if isinstance(self.ds, CoreDatasource) else None
903907
self.chat_question.terminologies = get_terminology_template(_session, self.chat_question.question,
904908
oid, ds_id)
905-
self.chat_question.data_training = get_training_template(_session, self.chat_question.question,
906-
ds_id, oid)
909+
if self.current_assistant and self.current_assistant.type == 1:
910+
self.chat_question.data_training = get_training_template(_session, self.chat_question.question,
911+
oid, None, self.current_assistant.id)
912+
else:
913+
self.chat_question.data_training = get_training_template(_session, self.chat_question.question,
914+
oid, ds_id)
907915
if SQLBotLicenseUtil.valid():
908916
self.chat_question.custom_prompt = find_custom_prompts(_session,
909917
CustomPromptTypeEnum.GENERATE_SQL,

backend/apps/data_training/curd/data_training.py

Lines changed: 82 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
from sqlalchemy import text
1010

1111
from apps.ai_model.embedding import EmbeddingModelCache
12-
from apps.data_training.models.data_training_model import DataTrainingInfo, DataTraining
12+
from apps.data_training.models.data_training_model import DataTrainingInfo, DataTraining, DataTrainingInfoResult
1313
from apps.datasource.models.datasource import CoreDatasource
14+
from apps.system.models.system_model import AssistantModel
1415
from apps.template.generate_chart.generator import get_base_data_training_template
1516
from common.core.config import settings
1617
from common.core.deps import SessionDep, Trans
@@ -19,7 +20,7 @@
1920

2021
def page_data_training(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
2122
oid: Optional[int] = 1):
22-
_list: List[DataTrainingInfo] = []
23+
_list: List[DataTrainingInfoResult] = []
2324

2425
current_page = max(1, current_page)
2526
page_size = max(10, page_size)
@@ -63,40 +64,60 @@ def page_data_training(session: SessionDep, current_page: int = 1, page_size: in
6364
DataTraining.create_time,
6465
DataTraining.description,
6566
DataTraining.enabled,
67+
DataTraining.advanced_application,
68+
AssistantModel.name.label('advanced_application_name'),
6669
)
6770
.outerjoin(CoreDatasource, and_(DataTraining.datasource == CoreDatasource.id))
71+
.outerjoin(AssistantModel,
72+
and_(DataTraining.advanced_application == AssistantModel.id, AssistantModel.type == 1))
6873
.where(and_(DataTraining.id.in_(paginated_parent_ids)))
6974
.order_by(DataTraining.create_time.desc())
7075
)
7176

7277
result = session.execute(stmt)
7378

7479
for row in result:
75-
_list.append(DataTrainingInfo(
76-
id=row.id,
77-
oid=row.oid,
80+
_list.append(DataTrainingInfoResult(
81+
id=str(row.id),
82+
oid=str(row.oid),
7883
datasource=row.datasource,
7984
datasource_name=row.name,
8085
question=row.question,
8186
create_time=row.create_time,
8287
description=row.description,
8388
enabled=row.enabled,
89+
advanced_application=str(row.advanced_application) if row.advanced_application else None,
90+
advanced_application_name=row.advanced_application_name,
8491
))
8592

8693
return current_page, page_size, total_count, total_pages, _list
8794

8895

8996
def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
9097
create_time = datetime.datetime.now()
91-
if info.datasource is None:
92-
raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))
98+
if info.datasource is None and info.advanced_application is None:
99+
if oid == 1:
100+
raise Exception(trans("i18n_data_training.datasource_assistant_cannot_be_none"))
101+
else:
102+
raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))
103+
93104
parent = DataTraining(question=info.question, create_time=create_time, description=info.description, oid=oid,
94-
datasource=info.datasource, enabled=info.enabled)
105+
datasource=info.datasource, enabled=info.enabled,
106+
advanced_application=info.advanced_application)
107+
108+
stmt = select(DataTraining.id).where(and_(DataTraining.question == info.question, DataTraining.oid == oid))
109+
110+
if info.datasource is not None and info.advanced_application is not None:
111+
stmt = stmt.where(
112+
or_(DataTraining.datasource == info.datasource,
113+
DataTraining.advanced_application == info.advanced_application))
114+
elif info.datasource is not None and info.advanced_application is None:
115+
stmt = stmt.where(and_(DataTraining.datasource == info.datasource))
116+
elif info.datasource is None and info.advanced_application is not None:
117+
stmt = stmt.where(and_(DataTraining.advanced_application == info.advanced_application))
118+
119+
exists = session.query(stmt.exists()).scalar()
95120

96-
exists = session.query(
97-
session.query(DataTraining).filter(
98-
and_(DataTraining.question == info.question, DataTraining.oid == oid,
99-
DataTraining.datasource == info.datasource)).exists()).scalar()
100121
if exists:
101122
raise Exception(trans("i18n_data_training.exists_in_db"))
102123

@@ -116,20 +137,32 @@ def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
116137

117138

118139
def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
119-
if info.datasource is None:
120-
raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))
140+
if info.datasource is None and info.advanced_application is None:
141+
if oid == 1:
142+
raise Exception(trans("i18n_data_training.datasource_assistant_cannot_be_none"))
143+
else:
144+
raise Exception(trans("i18n_data_training.datasource_cannot_be_none"))
121145

122146
count = session.query(DataTraining).filter(
123147
DataTraining.id == info.id
124148
).count()
125149
if count == 0:
126150
raise Exception(trans('i18n_data_training.data_training_not_exists'))
127151

128-
exists = session.query(
129-
session.query(DataTraining).filter(
130-
and_(DataTraining.question == info.question, DataTraining.oid == oid,
131-
DataTraining.datasource == info.datasource,
132-
DataTraining.id != info.id)).exists()).scalar()
152+
stmt = select(DataTraining.id).where(
153+
and_(DataTraining.question == info.question, DataTraining.oid == oid, DataTraining.id != info.id))
154+
155+
if info.datasource is not None and info.advanced_application is not None:
156+
stmt = stmt.where(
157+
or_(DataTraining.datasource == info.datasource,
158+
DataTraining.advanced_application == info.advanced_application))
159+
elif info.datasource is not None and info.advanced_application is None:
160+
stmt = stmt.where(and_(DataTraining.datasource == info.datasource))
161+
elif info.datasource is None and info.advanced_application is not None:
162+
stmt = stmt.where(and_(DataTraining.advanced_application == info.advanced_application))
163+
164+
exists = session.query(stmt.exists()).scalar()
165+
133166
if exists:
134167
raise Exception(trans("i18n_data_training.exists_in_db"))
135168

@@ -138,6 +171,7 @@ def update_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans
138171
description=info.description,
139172
datasource=info.datasource,
140173
enabled=info.enabled,
174+
advanced_application=info.advanced_application,
141175
)
142176
session.execute(stmt)
143177
session.commit()
@@ -231,9 +265,21 @@ def save_embeddings(session_maker, ids: List[int]):
231265
ORDER BY similarity DESC
232266
LIMIT {settings.EMBEDDING_DATA_TRAINING_TOP_COUNT}
233267
"""
268+
embedding_sql_in_advanced_application = f"""
269+
SELECT id, datasource, question, similarity
270+
FROM
271+
(SELECT id, datasource, question, oid, enabled,
272+
( 1 - (embedding <=> :embedding_array) ) AS similarity
273+
FROM data_training AS child
274+
) TEMP
275+
WHERE similarity > {settings.EMBEDDING_DATA_TRAINING_SIMILARITY} and oid = :oid and advanced_application = :advanced_application and enabled = true
276+
ORDER BY similarity DESC
277+
LIMIT {settings.EMBEDDING_DATA_TRAINING_TOP_COUNT}
278+
"""
234279

235280

236-
def select_training_by_question(session: SessionDep, question: str, oid: int, datasource: int):
281+
def select_training_by_question(session: SessionDep, question: str, oid: int, datasource: Optional[int] = None,
282+
advanced_application_id: Optional[int] = None):
237283
if question.strip() == "":
238284
return []
239285

@@ -248,10 +294,13 @@ def select_training_by_question(session: SessionDep, question: str, oid: int, da
248294
.where(
249295
and_(or_(text(":sentence ILIKE '%' || question || '%'"), text("question ILIKE '%' || :sentence || '%'")),
250296
DataTraining.oid == oid,
251-
DataTraining.datasource == datasource,
252-
DataTraining.enabled == True,)
297+
DataTraining.enabled == True)
253298
)
254299
)
300+
if advanced_application_id is not None:
301+
stmt = stmt.where(and_(DataTraining.advanced_application == advanced_application_id))
302+
else:
303+
stmt = stmt.where(and_(DataTraining.datasource == datasource))
255304

256305
results = session.execute(stmt, {'sentence': question}).fetchall()
257306

@@ -264,8 +313,13 @@ def select_training_by_question(session: SessionDep, question: str, oid: int, da
264313

265314
embedding = model.embed_query(question)
266315

267-
results = session.execute(text(embedding_sql),
268-
{'embedding_array': str(embedding), 'oid': oid, 'datasource': datasource})
316+
if advanced_application_id is not None:
317+
results = session.execute(text(embedding_sql_in_advanced_application),
318+
{'embedding_array': str(embedding), 'oid': oid,
319+
'advanced_application': advanced_application_id})
320+
else:
321+
results = session.execute(text(embedding_sql),
322+
{'embedding_array': str(embedding), 'oid': oid, 'datasource': datasource})
269323

270324
for row in results:
271325
_list.append(DataTraining(id=row.id, question=row.question))
@@ -328,12 +382,13 @@ def to_xml_string(_dict: list[dict] | dict, root: str = 'sql-examples') -> str:
328382
return pretty_xml
329383

330384

331-
def get_training_template(session: SessionDep, question: str, datasource: int, oid: Optional[int] = 1) -> str:
385+
def get_training_template(session: SessionDep, question: str, oid: Optional[int] = 1, datasource: Optional[int] = None,
386+
advanced_application_id: Optional[int] = None) -> str:
332387
if not oid:
333388
oid = 1
334-
if not datasource:
389+
if not datasource and not advanced_application_id:
335390
return ''
336-
_results = select_training_by_question(session, question, oid, datasource)
391+
_results = select_training_by_question(session, question, oid, datasource, advanced_application_id)
337392
if _results and len(_results) > 0:
338393
data_training = to_xml_string(_results)
339394
template = get_base_data_training_template().format(data_training=data_training)

backend/apps/data_training/models/data_training_model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class DataTraining(SQLModel, table=True):
1717
description: Optional[str] = Field(sa_column=Column(Text, nullable=True))
1818
embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True))
1919
enabled: Optional[bool] = Field(sa_column=Column(Boolean, default=True))
20+
advanced_application: Optional[int] = Field(sa_column=Column(BigInteger, nullable=True))
2021

2122

2223
class DataTrainingInfo(BaseModel):
@@ -28,3 +29,18 @@ class DataTrainingInfo(BaseModel):
2829
question: Optional[str] = None
2930
description: Optional[str] = None
3031
enabled: Optional[bool] = True
32+
advanced_application: Optional[int] = None
33+
advanced_application_name: Optional[str] = None
34+
35+
36+
class DataTrainingInfoResult(BaseModel):
37+
id: Optional[str] = None
38+
oid: Optional[str] = None
39+
datasource: Optional[int] = None
40+
datasource_name: Optional[str] = None
41+
create_time: Optional[datetime] = None
42+
question: Optional[str] = None
43+
description: Optional[str] = None
44+
enabled: Optional[bool] = True
45+
advanced_application: Optional[str] = None
46+
advanced_application_name: Optional[str] = None

0 commit comments

Comments
 (0)