Skip to content

Commit bd5015c

Browse files
committed
feat: add export terminologies
1 parent aa81a8a commit bd5015c

File tree

11 files changed

+324
-265
lines changed

11 files changed

+324
-265
lines changed

backend/apps/terminology/api/terminology.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
import asyncio
2+
import io
13
from typing import Optional
24

5+
import pandas as pd
36
from fastapi import APIRouter, Query
7+
from fastapi.responses import StreamingResponse
48

9+
from apps.chat.models.chat_model import AxisObj
10+
from apps.chat.task.llm import LLMService
511
from apps.terminology.curd.terminology import page_terminology, create_terminology, update_terminology, \
6-
delete_terminology, enable_terminology
12+
delete_terminology, enable_terminology, get_all_terminology
713
from apps.terminology.models.terminology_model import TerminologyInfo
814
from common.core.deps import SessionDep, CurrentUser, Trans
915

@@ -42,3 +48,44 @@ async def delete(session: SessionDep, id_list: list[int]):
4248
@router.get("/{id}/enable/{enabled}")
4349
async def enable(session: SessionDep, id: int, enabled: bool, trans: Trans):
4450
enable_terminology(session, id, enabled, trans)
51+
52+
53+
@router.get("/export")
54+
async def export_excel(session: SessionDep, trans: Trans, current_user: CurrentUser,
55+
word: Optional[str] = Query(None, description="搜索术语(可选)")):
56+
def inner():
57+
_list = get_all_terminology(session, word, oid=current_user.oid)
58+
59+
data_list = []
60+
for obj in _list:
61+
_data = {
62+
"word": obj.word,
63+
"other_words": ', '.join(obj.other_words) if obj.other_words else '',
64+
"description": obj.description,
65+
"all_data_sources": 'Y' if obj.specific_ds else 'N',
66+
"datasource": ', '.join(obj.datasource_names) if obj.datasource_names else '',
67+
}
68+
data_list.append(_data)
69+
70+
fields = []
71+
fields.append(AxisObj(name=trans('i18n_terminology.term_name'), value='word'))
72+
fields.append(AxisObj(name=trans('i18n_terminology.synonyms'), value='other_words'))
73+
fields.append(AxisObj(name=trans('i18n_terminology.term_description'), value='description'))
74+
fields.append(AxisObj(name=trans('i18n_terminology.effective_data_sources'), value='datasource'))
75+
fields.append(AxisObj(name=trans('i18n_terminology.all_data_sources'), value='all_data_sources'))
76+
77+
md_data, _fields_list = LLMService.convert_object_array_for_pandas(fields, data_list)
78+
79+
df = pd.DataFrame(md_data, columns=_fields_list)
80+
81+
buffer = io.BytesIO()
82+
83+
with pd.ExcelWriter(buffer, engine='xlsxwriter',
84+
engine_kwargs={'options': {'strings_to_numbers': False}}) as writer:
85+
df.to_excel(writer, sheet_name='Sheet1', index=False)
86+
87+
buffer.seek(0)
88+
return io.BytesIO(buffer.getvalue())
89+
90+
result = await asyncio.to_thread(inner)
91+
return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")

backend/apps/terminology/curd/terminology.py

Lines changed: 118 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,18 @@
1717
from common.utils.embedding_threads import run_save_terminology_embeddings
1818

1919

20-
def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
21-
oid: Optional[int] = 1):
22-
_list: List[TerminologyInfo] = []
23-
20+
def get_terminology_base_query(oid: int, name: Optional[str] = None):
21+
"""
22+
获取术语查询的基础查询结构
23+
"""
2424
child = aliased(Terminology)
2525

26-
current_page = max(1, current_page)
27-
page_size = max(10, page_size)
28-
29-
total_count = 0
30-
total_pages = 0
31-
3226
if name and name.strip() != "":
3327
keyword_pattern = f"%{name.strip()}%"
3428
# 步骤1:先找到所有匹配的节点ID(无论是父节点还是子节点)
3529
matched_ids_subquery = (
3630
select(Terminology.id)
37-
.where(and_(Terminology.word.ilike(keyword_pattern), Terminology.oid == oid)) # LIKE查询条件
31+
.where(and_(Terminology.word.ilike(keyword_pattern), Terminology.oid == oid))
3832
.subquery()
3933
)
4034

@@ -51,161 +45,118 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
5145
)
5246
.where(Terminology.pid.is_(None)) # 只取父节点
5347
)
48+
else:
49+
parent_ids_subquery = (
50+
select(Terminology.id)
51+
.where(and_(Terminology.pid.is_(None), Terminology.oid == oid))
52+
)
53+
54+
return parent_ids_subquery, child
5455

55-
count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
56-
total_count = session.execute(count_stmt).scalar()
57-
total_pages = (total_count + page_size - 1) // page_size
5856

59-
if current_page > total_pages:
60-
current_page = 1
57+
def build_terminology_query(session: SessionDep, oid: int, name: Optional[str] = None,
58+
paginate: bool = True, current_page: int = 1, page_size: int = 10):
59+
"""
60+
构建术语查询的通用方法
61+
"""
62+
parent_ids_subquery, child = get_terminology_base_query(oid, name)
63+
64+
# 计算总数
65+
count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
66+
total_count = session.execute(count_stmt).scalar()
67+
68+
if paginate:
69+
# 分页处理
70+
page_size = max(10, page_size)
71+
total_pages = (total_count + page_size - 1) // page_size
72+
current_page = max(1, min(current_page, total_pages)) if total_pages > 0 else 1
6173

62-
# 步骤3:获取分页后的父节点ID
6374
paginated_parent_ids = (
6475
parent_ids_subquery
6576
.order_by(Terminology.create_time.desc())
6677
.offset((current_page - 1) * page_size)
6778
.limit(page_size)
6879
.subquery()
6980
)
70-
71-
# 步骤4:获取这些父节点的childrenNames
72-
children_subquery = (
73-
select(
74-
child.pid,
75-
func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
76-
)
77-
.where(child.pid.isnot(None))
78-
.group_by(child.pid)
79-
.subquery()
80-
)
81-
82-
# 创建子查询来获取数据源名称,添加类型转换
83-
datasource_names_subquery = (
84-
select(
85-
func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'),
86-
Terminology.id.label('term_id')
87-
)
88-
.where(Terminology.id.in_(paginated_parent_ids))
89-
.subquery()
90-
)
91-
92-
# 主查询
93-
stmt = (
94-
select(
95-
Terminology.id,
96-
Terminology.word,
97-
Terminology.create_time,
98-
Terminology.description,
99-
Terminology.specific_ds,
100-
Terminology.datasource_ids,
101-
children_subquery.c.other_words,
102-
func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names'),
103-
Terminology.enabled
104-
)
105-
.outerjoin(
106-
children_subquery,
107-
Terminology.id == children_subquery.c.pid
108-
)
109-
# 关联数据源名称子查询和 CoreDatasource 表
110-
.outerjoin(
111-
datasource_names_subquery,
112-
datasource_names_subquery.c.term_id == Terminology.id
113-
)
114-
.outerjoin(
115-
CoreDatasource,
116-
CoreDatasource.id == datasource_names_subquery.c.ds_id
117-
)
118-
.where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
119-
.group_by(
120-
Terminology.id,
121-
Terminology.word,
122-
Terminology.create_time,
123-
Terminology.description,
124-
Terminology.specific_ds,
125-
Terminology.datasource_ids,
126-
children_subquery.c.other_words,
127-
Terminology.enabled
128-
)
129-
.order_by(Terminology.create_time.desc())
130-
)
13181
else:
132-
parent_ids_subquery = (
133-
select(Terminology.id)
134-
.where(and_(Terminology.pid.is_(None), Terminology.oid == oid)) # 只取父节点
135-
)
136-
count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
137-
total_count = session.execute(count_stmt).scalar()
138-
total_pages = (total_count + page_size - 1) // page_size
139-
140-
if current_page > total_pages:
141-
current_page = 1
82+
# 不分页,获取所有数据
83+
total_pages = 1
84+
current_page = 1
85+
page_size = total_count if total_count > 0 else 1
14286

14387
paginated_parent_ids = (
14488
parent_ids_subquery
14589
.order_by(Terminology.create_time.desc())
146-
.offset((current_page - 1) * page_size)
147-
.limit(page_size)
14890
.subquery()
14991
)
15092

151-
children_subquery = (
152-
select(
153-
child.pid,
154-
func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
155-
)
156-
.where(child.pid.isnot(None))
157-
.group_by(child.pid)
158-
.subquery()
93+
# 构建公共查询部分
94+
children_subquery = (
95+
select(
96+
child.pid,
97+
func.jsonb_agg(child.word).filter(child.word.isnot(None)).label('other_words')
15998
)
99+
.where(child.pid.isnot(None))
100+
.group_by(child.pid)
101+
.subquery()
102+
)
160103

161-
# 创建子查询来获取数据源名称
162-
datasource_names_subquery = (
163-
select(
164-
func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'),
165-
Terminology.id.label('term_id')
166-
)
167-
.where(Terminology.id.in_(paginated_parent_ids))
168-
.subquery()
104+
# 创建子查询来获取数据源名称
105+
datasource_names_subquery = (
106+
select(
107+
func.jsonb_array_elements(Terminology.datasource_ids).cast(BigInteger).label('ds_id'),
108+
Terminology.id.label('term_id')
169109
)
110+
.where(Terminology.id.in_(paginated_parent_ids))
111+
.subquery()
112+
)
170113

171-
stmt = (
172-
select(
173-
Terminology.id,
174-
Terminology.word,
175-
Terminology.create_time,
176-
Terminology.description,
177-
Terminology.specific_ds,
178-
Terminology.datasource_ids,
179-
children_subquery.c.other_words,
180-
func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names'),
181-
Terminology.enabled
182-
)
183-
.outerjoin(
184-
children_subquery,
185-
Terminology.id == children_subquery.c.pid
186-
)
187-
# 关联数据源名称子查询和 CoreDatasource 表
188-
.outerjoin(
189-
datasource_names_subquery,
190-
datasource_names_subquery.c.term_id == Terminology.id
191-
)
192-
.outerjoin(
193-
CoreDatasource,
194-
CoreDatasource.id == datasource_names_subquery.c.ds_id
195-
)
196-
.where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
197-
.group_by(Terminology.id,
198-
Terminology.word,
199-
Terminology.create_time,
200-
Terminology.description,
201-
Terminology.specific_ds,
202-
Terminology.datasource_ids,
203-
children_subquery.c.other_words,
204-
Terminology.enabled
205-
)
206-
.order_by(Terminology.create_time.desc())
114+
stmt = (
115+
select(
116+
Terminology.id,
117+
Terminology.word,
118+
Terminology.create_time,
119+
Terminology.description,
120+
Terminology.specific_ds,
121+
Terminology.datasource_ids,
122+
children_subquery.c.other_words,
123+
func.jsonb_agg(CoreDatasource.name).filter(CoreDatasource.id.isnot(None)).label('datasource_names'),
124+
Terminology.enabled
207125
)
126+
.outerjoin(
127+
children_subquery,
128+
Terminology.id == children_subquery.c.pid
129+
)
130+
.outerjoin(
131+
datasource_names_subquery,
132+
datasource_names_subquery.c.term_id == Terminology.id
133+
)
134+
.outerjoin(
135+
CoreDatasource,
136+
CoreDatasource.id == datasource_names_subquery.c.ds_id
137+
)
138+
.where(and_(Terminology.id.in_(paginated_parent_ids), Terminology.oid == oid))
139+
.group_by(
140+
Terminology.id,
141+
Terminology.word,
142+
Terminology.create_time,
143+
Terminology.description,
144+
Terminology.specific_ds,
145+
Terminology.datasource_ids,
146+
children_subquery.c.other_words,
147+
Terminology.enabled
148+
)
149+
.order_by(Terminology.create_time.desc())
150+
)
151+
152+
return stmt, total_count, total_pages, current_page, page_size
153+
208154

155+
def execute_terminology_query(session: SessionDep, stmt) -> List[TerminologyInfo]:
156+
"""
157+
执行查询并返回术语信息列表
158+
"""
159+
_list = []
209160
result = session.execute(stmt)
210161

211162
for row in result:
@@ -221,9 +172,34 @@ def page_terminology(session: SessionDep, current_page: int = 1, page_size: int
221172
enabled=row.enabled if row.enabled is not None else False,
222173
))
223174

175+
return _list
176+
177+
178+
def page_terminology(session: SessionDep, current_page: int = 1, page_size: int = 10,
179+
name: Optional[str] = None, oid: Optional[int] = 1):
180+
"""
181+
分页查询术语(原方法保持不变)
182+
"""
183+
stmt, total_count, total_pages, current_page, page_size = build_terminology_query(
184+
session, oid, name, True, current_page, page_size
185+
)
186+
_list = execute_terminology_query(session, stmt)
187+
224188
return current_page, page_size, total_count, total_pages, _list
225189

226190

191+
def get_all_terminology(session: SessionDep, name: Optional[str] = None, oid: Optional[int] = 1):
192+
"""
193+
获取所有术语(不分页)
194+
"""
195+
stmt, total_count, total_pages, current_page, page_size = build_terminology_query(
196+
session, oid, name, False
197+
)
198+
_list = execute_terminology_query(session, stmt)
199+
200+
return _list
201+
202+
227203
def create_terminology(session: SessionDep, info: TerminologyInfo, oid: int, trans: Trans):
228204
create_time = datetime.datetime.now()
229205

0 commit comments

Comments
 (0)