Skip to content

Commit e77bdc7

Browse files
committed
feat: Support tag in knowledge_workflow
1 parent 735d3d0 commit e77bdc7

File tree

6 files changed

+113
-12
lines changed

6 files changed

+113
-12
lines changed

apps/application/flow/step_node/knowledge_write_node/impl/base_knowledge_write_node.py

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
@desc:
88
"""
99
from functools import reduce
10-
from typing import Dict, List
10+
from typing import Dict, List, Any
1111
import uuid_utils.compat as uuid
1212
from django.db.models import QuerySet
1313
from django.db.models.aggregates import Max
@@ -18,7 +18,8 @@
1818
from application.flow.step_node.knowledge_write_node.i_knowledge_write_node import IKnowledgeWriteNode
1919
from common.chunk import text_to_chunk
2020
from common.utils.common import bulk_create_in_batches
21-
from knowledge.models import Document, KnowledgeType, Paragraph, File, FileSourceType, Problem, ProblemParagraphMapping
21+
from knowledge.models import Document, KnowledgeType, Paragraph, File, FileSourceType, Problem, ProblemParagraphMapping, \
22+
Tag, DocumentTag
2223
from knowledge.serializers.common import ProblemParagraphObject, ProblemParagraphManage
2324
from knowledge.serializers.document import DocumentSerializers
2425

@@ -33,10 +34,16 @@ class ParagraphInstanceSerializer(serializers.Serializer):
3334
chunks = serializers.ListField(required=False, child=serializers.CharField(required=True))
3435

3536

37+
class TagInstanceSerializer(serializers.Serializer):
38+
key = serializers.CharField(required=True, max_length=64, label=_('Tag Key'))
39+
value = serializers.CharField(required=True, max_length=128, label=_('Tag Value'))
40+
41+
3642
class KnowledgeWriteParamSerializer(serializers.Serializer):
3743
name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1,
3844
source=_('document name'))
3945
meta = serializers.DictField(required=False)
46+
tags = serializers.ListField(required=False, label=_('Tags'), child=TagInstanceSerializer())
4047
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
4148
source_file_id = serializers.UUIDField(required=False, allow_null=True)
4249

@@ -51,6 +58,7 @@ def convert_uuid_to_str(obj):
5158
else:
5259
return obj
5360

61+
5462
def link_file(source_file_id, document_id):
5563
if source_file_id is None:
5664
return
@@ -70,14 +78,15 @@ def link_file(source_file_id, document_id):
7078
# 保存文件内容和元数据
7179
new_file.save(file_content)
7280

81+
7382
def get_paragraph_problem_model(knowledge_id: str, document_id: str, instance: Dict):
7483
paragraph = Paragraph(
7584
id=uuid.uuid7(),
7685
document_id=document_id,
7786
content=instance.get("content"),
7887
knowledge_id=knowledge_id,
7988
title=instance.get("title") if 'title' in instance else '',
80-
chunks = instance.get('chunks') if 'chunks' in instance else text_to_chunk(instance.get("content")),
89+
chunks=instance.get('chunks') if 'chunks' in instance else text_to_chunk(instance.get("content")),
8190
)
8291

8392
problem_paragraph_object_list = [ProblemParagraphObject(
@@ -136,6 +145,65 @@ def get_document_paragraph_model(knowledge_id: str, instance: Dict):
136145
instance.get('paragraphs') if 'paragraphs' in instance else []
137146
)
138147

148+
def save_knowledge_tags(knowledge_id: str, tags: List[Dict[str,Any]]):
149+
150+
existed_tags_dict = {
151+
(key, value): str(tag_id)
152+
for key,value,tag_id in QuerySet(Tag).filter(knowledge_id=knowledge_id).values_list("key", "value", "id")
153+
}
154+
155+
tag_model_list = []
156+
new_tag_dict = {}
157+
for tag in tags:
158+
key = tag.get("key")
159+
value = tag.get("value")
160+
161+
if (key,value) not in existed_tags_dict:
162+
tag_model = Tag(
163+
id=uuid.uuid7(),
164+
knowledge_id=knowledge_id,
165+
key=key,
166+
value=value
167+
)
168+
tag_model_list.append(tag_model)
169+
new_tag_dict[(key,value)] = str(tag_model.id)
170+
171+
if tag_model_list:
172+
Tag.objects.bulk_create(tag_model_list)
173+
174+
all_tag_dict={**existed_tags_dict,**new_tag_dict}
175+
176+
return all_tag_dict, new_tag_dict
177+
178+
def batch_add_document_tag(document_tag_map: Dict[str, List[str]]):
179+
"""
180+
批量添加文档-标签关联
181+
document_tag_map: {document_id: [tag_id1, tag_id2, ...]}
182+
"""
183+
all_document_ids = list(document_tag_map.keys())
184+
all_tag_ids = list(set(tag_id for tag_ids in document_tag_map.values() for tag_id in tag_ids))
185+
186+
# 查询已存在的文档-标签关联
187+
existed_relations = set(
188+
QuerySet(DocumentTag).filter(
189+
document_id__in=all_document_ids,
190+
tag_id__in=all_tag_ids
191+
).values_list('document_id', 'tag_id')
192+
)
193+
194+
new_relations = [
195+
DocumentTag(
196+
id=uuid.uuid7(),
197+
document_id=doc_id,
198+
tag_id=tag_id,
199+
)
200+
for doc_id, tag_ids in document_tag_map.items()
201+
for tag_id in tag_ids
202+
if (doc_id,tag_id) not in existed_relations
203+
]
204+
205+
if new_relations:
206+
QuerySet(DocumentTag).bulk_create(new_relations)
139207

140208
class BaseKnowledgeWriteNode(IKnowledgeWriteNode):
141209

@@ -153,6 +221,11 @@ def save(self, document_list):
153221
document_model_list = []
154222
paragraph_model_list = []
155223
problem_paragraph_object_list = []
224+
# 所有标签
225+
knowledge_tag_list = []
226+
# 文档标签映射关系
227+
document_tags_map = {}
228+
knowledge_tag_dict = {}
156229

157230
for document in document_list:
158231
document_paragraph_dict_model = get_document_paragraph_model(
@@ -162,10 +235,38 @@ def save(self, document_list):
162235
document_instance = document_paragraph_dict_model.get('document')
163236
link_file(document.get("source_file_id"), document_instance.id)
164237
document_model_list.append(document_instance)
238+
# 收集标签
239+
single_document_tag_list = document.get("tags", [])
240+
# 去重传入的标签
241+
for tag in single_document_tag_list:
242+
tag_key = (tag['key'], tag['value'])
243+
if tag_key not in knowledge_tag_dict:
244+
knowledge_tag_dict[tag_key]= tag
245+
246+
if single_document_tag_list:
247+
document_tags_map[str(document_instance.id)] = single_document_tag_list
248+
165249
for paragraph in document_paragraph_dict_model.get("paragraph_model_list"):
166250
paragraph_model_list.append(paragraph)
167251
for problem_paragraph_object in document_paragraph_dict_model.get("problem_paragraph_object_list"):
168252
problem_paragraph_object_list.append(problem_paragraph_object)
253+
knowledge_tag_list = list(knowledge_tag_dict.values())
254+
# 保存所有文档中含有的标签到知识库
255+
if knowledge_tag_list:
256+
all_tag_dict, new_tag_dict = save_knowledge_tags(knowledge_id, knowledge_tag_list)
257+
# 构建文档-标签ID映射
258+
document_tag_id_map = {}
259+
# 为每个文档添加其对应的标签
260+
for doc_id, doc_tags in document_tags_map.items():
261+
doc_tag_ids = [
262+
all_tag_dict[(tag.get("key"),tag.get("value"))]
263+
for tag in doc_tags
264+
if (tag.get("key"),tag.get("value")) in all_tag_dict
265+
]
266+
if doc_tag_ids:
267+
document_tag_id_map[doc_id] = doc_tag_ids
268+
if document_tag_id_map:
269+
batch_add_document_tag(document_tag_id_map)
169270

170271
problem_model_list, problem_paragraph_mapping_list = (
171272
ProblemParagraphManage(problem_paragraph_object_list, knowledge_id).to_problem_model_list()

apps/locales/en_US/LC_MESSAGES/django.po

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8810,7 +8810,7 @@ msgstr ""
88108810
msgid "Audio file recognition - Tongyi Qwen"
88118811
msgstr ""
88128812

8813-
msgid "Audio file recognition - Fun-ASR/Paraformer/SenseVoice"
8813+
msgid "Real-time speech recognition - Fun-ASR/Paraformer"
88148814
msgstr ""
88158815

88168816
msgid "Qwen-Omni"

apps/locales/zh_CN/LC_MESSAGES/django.po

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8936,8 +8936,8 @@ msgstr "阻止不安全的重定向到内部主机"
89368936
msgid "Audio file recognition - Tongyi Qwen"
89378937
msgstr "录音文件识别-通义千问"
89388938

8939-
msgid "Audio file recognition - Fun-ASR/Paraformer/SenseVoice"
8940-
msgstr "录音文件识别-Fun-ASR/Paraformer/SenseVoice"
8939+
msgid "Real-time speech recognition - Fun-ASR/Paraformer"
8940+
msgstr "实时语音识别-Fun-ASR/Paraformer"
89418941

89428942
msgid "Qwen-Omni"
89438943
msgstr "多模态"

apps/locales/zh_Hant/LC_MESSAGES/django.po

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8936,8 +8936,8 @@ msgstr "阻止不安全的重定向到內部主機"
89368936
msgid "Audio file recognition - Tongyi Qwen"
89378937
msgstr "錄音文件識別-通義千問"
89388938

8939-
msgid "Audio file recognition - Fun-ASR/Paraformer/SenseVoice"
8940-
msgstr "錄音文件識別-Fun-ASR/Paraformer/SenseVoice"
8939+
msgid "Real-time speech recognition - Fun-ASR/Paraformer"
8940+
msgstr "實時語音識別-Fun-ASR/Paraformer"
89418941

89428942
msgid "Qwen-Omni"
89438943
msgstr "多模態"

apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt/default_stt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818

1919

2020
class AliyunBaiLianDefaultSTTModelCredential(BaseForm, BaseModelCredential):
21-
type = forms.Radio(_("Type"), required=True, text_field='label', default_value='qwen', provider='', method='',
21+
type = forms.SingleSelect(_("API"), required=True, text_field='label', default_value='qwen', provider='', method='',
2222
value_field='value', option_list=[
2323
{'label': _('Audio file recognition - Tongyi Qwen'),
2424
'value': 'qwen'},
2525
{'label': _('Qwen-Omni'),
2626
'value': 'omni'},
27-
{'label': _('Audio file recognition - Fun-ASR/Paraformer/SenseVoice'),
27+
{'label': _('Real-time speech recognition - Fun-ASR/Paraformer'),
2828
'value': 'other'}
2929
])
3030
api_url = forms.TextInputField(_('API URL'), required=True, relation_show_field_dict={'type': ['qwen', 'omni']})

apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt/omni_stt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def speech_to_text(self, audio_file):
6868
"format": "mp3",
6969
},
7070
},
71-
{"type": "text", "text": self.params.get('CueWord')},
71+
{"type": "text", "text": self.params.get('CueWord') or '这段音频在说什么'},
7272
],
7373
},
7474
],
@@ -77,7 +77,7 @@ def speech_to_text(self, audio_file):
7777
# stream 必须设置为 True,否则会报错
7878
stream=True,
7979
stream_options={"include_usage": True},
80-
extra_body=self.params
80+
extra_body = {'enable_thinking': False, **self.params},
8181
)
8282
result = []
8383
for chunk in completion:

0 commit comments

Comments
 (0)