Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@desc:
"""
from functools import reduce
from typing import Dict, List
from typing import Dict, List, Any
import uuid_utils.compat as uuid
from django.db.models import QuerySet
from django.db.models.aggregates import Max
Expand All @@ -18,7 +18,8 @@
from application.flow.step_node.knowledge_write_node.i_knowledge_write_node import IKnowledgeWriteNode
from common.chunk import text_to_chunk
from common.utils.common import bulk_create_in_batches
from knowledge.models import Document, KnowledgeType, Paragraph, File, FileSourceType, Problem, ProblemParagraphMapping
from knowledge.models import Document, KnowledgeType, Paragraph, File, FileSourceType, Problem, ProblemParagraphMapping, \
Tag, DocumentTag
from knowledge.serializers.common import ProblemParagraphObject, ProblemParagraphManage
from knowledge.serializers.document import DocumentSerializers

Expand All @@ -33,10 +34,16 @@ class ParagraphInstanceSerializer(serializers.Serializer):
chunks = serializers.ListField(required=False, child=serializers.CharField(required=True))


class TagInstanceSerializer(serializers.Serializer):
    """Validate a single tag supplied as a key/value pair."""

    # Tag key: mandatory string, at most 64 characters.
    key = serializers.CharField(label=_('Tag Key'), max_length=64, required=True)
    # Tag value: mandatory string, at most 128 characters.
    value = serializers.CharField(label=_('Tag Value'), max_length=128, required=True)


class KnowledgeWriteParamSerializer(serializers.Serializer):
name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1,
source=_('document name'))
meta = serializers.DictField(required=False)
tags = serializers.ListField(required=False, label=_('Tags'), child=TagInstanceSerializer())
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
source_file_id = serializers.UUIDField(required=False, allow_null=True)

Expand All @@ -51,6 +58,7 @@ def convert_uuid_to_str(obj):
else:
return obj


def link_file(source_file_id, document_id):
if source_file_id is None:
return
Expand All @@ -70,14 +78,15 @@ def link_file(source_file_id, document_id):
# 保存文件内容和元数据
new_file.save(file_content)


def get_paragraph_problem_model(knowledge_id: str, document_id: str, instance: Dict):
paragraph = Paragraph(
id=uuid.uuid7(),
document_id=document_id,
content=instance.get("content"),
knowledge_id=knowledge_id,
title=instance.get("title") if 'title' in instance else '',
chunks = instance.get('chunks') if 'chunks' in instance else text_to_chunk(instance.get("content")),
chunks=instance.get('chunks') if 'chunks' in instance else text_to_chunk(instance.get("content")),
)

problem_paragraph_object_list = [ProblemParagraphObject(
Expand Down Expand Up @@ -136,6 +145,65 @@ def get_document_paragraph_model(knowledge_id: str, instance: Dict):
instance.get('paragraphs') if 'paragraphs' in instance else []
)

def save_knowledge_tags(knowledge_id: str, tags: List[Dict[str, Any]]):
    """
    Make sure every (key, value) pair in ``tags`` exists as a Tag of the knowledge base.

    Pairs already stored for ``knowledge_id`` are reused; missing pairs are
    inserted with a single ``bulk_create``. A pair that occurs more than once
    in ``tags`` is inserted only once.

    :param knowledge_id: id of the knowledge base the tags belong to
    :param tags: dicts carrying a "key" and a "value" entry
    :return: ``(all_tag_dict, new_tag_dict)``; both map ``(key, value)`` to the tag id
        as a string. ``all_tag_dict`` contains the stored and the newly created
        tags, ``new_tag_dict`` only the ones created by this call.
    """
    existed_tags_dict = {
        (key, value): str(tag_id)
        for key, value, tag_id in
        QuerySet(Tag).filter(knowledge_id=knowledge_id).values_list("key", "value", "id")
    }

    tag_model_list = []
    new_tag_dict = {}
    for tag in tags:
        pair = (tag.get("key"), tag.get("value"))
        # Skip pairs that are stored already or were queued earlier in this call;
        # without the second check a repeated pair would be inserted twice.
        if pair in existed_tags_dict or pair in new_tag_dict:
            continue

        tag_model = Tag(
            id=uuid.uuid7(),
            knowledge_id=knowledge_id,
            key=pair[0],
            value=pair[1],
        )
        tag_model_list.append(tag_model)
        new_tag_dict[pair] = str(tag_model.id)

    if tag_model_list:
        Tag.objects.bulk_create(tag_model_list)

    all_tag_dict = {**existed_tags_dict, **new_tag_dict}

    return all_tag_dict, new_tag_dict

def batch_add_document_tag(document_tag_map: Dict[str, List[str]]):
    """
    Create document-tag relations in bulk, skipping relations that already exist.

    :param document_tag_map: ``{document_id: [tag_id1, tag_id2, ...]}``
    :return: None
    """
    # Normalise every id to str and drop repeated (document, tag) pairs while
    # keeping their first-seen order.
    wanted_pairs = list(dict.fromkeys(
        (str(doc_id), str(tag_id))
        for doc_id, tag_ids in document_tag_map.items()
        for tag_id in tag_ids
    ))
    if not wanted_pairs:
        # Nothing to link, so no query is needed.
        return

    all_document_ids = list({doc_id for doc_id, _tag_id in wanted_pairs})
    all_tag_ids = list({tag_id for _doc_id, tag_id in wanted_pairs})

    # Look up the relations that are already stored. The database may hand the
    # ids back as UUID objects while the map carries strings, so both sides are
    # compared as strings; otherwise the membership test below never matches.
    existed_relations = {
        (str(doc_id), str(tag_id))
        for doc_id, tag_id in QuerySet(DocumentTag).filter(
            document_id__in=all_document_ids,
            tag_id__in=all_tag_ids
        ).values_list('document_id', 'tag_id')
    }

    new_relations = [
        DocumentTag(
            id=uuid.uuid7(),
            document_id=doc_id,
            tag_id=tag_id,
        )
        for doc_id, tag_id in wanted_pairs
        if (doc_id, tag_id) not in existed_relations
    ]

    if new_relations:
        QuerySet(DocumentTag).bulk_create(new_relations)

class BaseKnowledgeWriteNode(IKnowledgeWriteNode):

Expand All @@ -153,6 +221,11 @@ def save(self, document_list):
document_model_list = []
paragraph_model_list = []
problem_paragraph_object_list = []
# 所有标签
knowledge_tag_list = []
# 文档标签映射关系
document_tags_map = {}
knowledge_tag_dict = {}

for document in document_list:
document_paragraph_dict_model = get_document_paragraph_model(
Expand All @@ -162,10 +235,38 @@ def save(self, document_list):
document_instance = document_paragraph_dict_model.get('document')
link_file(document.get("source_file_id"), document_instance.id)
document_model_list.append(document_instance)
# 收集标签
single_document_tag_list = document.get("tags", [])
# 去重传入的标签
for tag in single_document_tag_list:
tag_key = (tag['key'], tag['value'])
if tag_key not in knowledge_tag_dict:
knowledge_tag_dict[tag_key]= tag

if single_document_tag_list:
document_tags_map[str(document_instance.id)] = single_document_tag_list

for paragraph in document_paragraph_dict_model.get("paragraph_model_list"):
paragraph_model_list.append(paragraph)
for problem_paragraph_object in document_paragraph_dict_model.get("problem_paragraph_object_list"):
problem_paragraph_object_list.append(problem_paragraph_object)
knowledge_tag_list = list(knowledge_tag_dict.values())
# 保存所有文档中含有的标签到知识库
if knowledge_tag_list:
all_tag_dict, new_tag_dict = save_knowledge_tags(knowledge_id, knowledge_tag_list)
# 构建文档-标签ID映射
document_tag_id_map = {}
# 为每个文档添加其对应的标签
for doc_id, doc_tags in document_tags_map.items():
doc_tag_ids = [
all_tag_dict[(tag.get("key"),tag.get("value"))]
for tag in doc_tags
if (tag.get("key"),tag.get("value")) in all_tag_dict
]
if doc_tag_ids:
document_tag_id_map[doc_id] = doc_tag_ids
if document_tag_id_map:
batch_add_document_tag(document_tag_id_map)

problem_model_list, problem_paragraph_mapping_list = (
ProblemParagraphManage(problem_paragraph_object_list, knowledge_id).to_problem_model_list()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided Python code is generally well-structured and follows Django best practices. However, there are a few areas where improvements could be made:

  1. Imports: The typing.Any import can often lead to type-checking errors. It might be better to specify more precise types where possible.

  2. Method Names: Method names like convert_uuid_to_str, link_file, etc., are good for clarity but consider using underscores consistently (snake_case) instead of hyphens (-).

  3. Error Handling: Ensure that error handling is robust enough to manage any exceptions that may arise during file operations or database interactions.

  4. Database Operations: For larger datasets, ensure that transactions are used to group multiple SQL queries into one transaction block, which can improve performance and consistency.

  5. Serialization: Consider simplifying the serialization process by using Django's built-in serializers where appropriate, which can reduce boilerplate code.

  6. Functionality: Verify that all functions, particularly those related to data insertion, handle edge cases and invalid inputs gracefully.

  7. Code Comments: Add comments explaining complex logic or sections of the code, especially around multi-line comments or blocks of repeated code.

Here are some specific suggestions:

# Example refinement: Using snake_case for method names
def convert_uuid_to_str(obj):
    # Implementation...
    pass

# Adding error handling
try:
    link_file(source_file_id, document_id)
except Exception as e:
    print(f"Error linking file: {e}")

# Simplified serialization with Django SerializerUtils if applicable
from django.core.exceptions import ValidationError

def create_serializer(data):
    try:
        serializer = MyDjangoModelSerializer(data=data)
        serializer.is_valid(raise_exception=True)
        return serializer.validated_data
    except ValidationError as e:
        raise SerializationException(str(e))

# Improved function for tag management
@transaction.atomic
def save_knowledge_tags(concerned_document_id, concerned_tags):
    old_tag_models = DocumentTag.objects.filter(document_id__in=[concerned_document_id])
    
    new_tag_list = []
    existing_tag_dict = {}

    for tag in concerned_tags:
        key_value_pair = (tag['key'], tag['value'])
        if key_value_pair not in existing_tag_dict.values():
            new_tag = DocumentTag(document_id=concerned_document_id, **key_value_pair)
            new_tag_list.append(new_tag)
            existing_tag_dict[key_value_pair] = f"{new_tag.id}"

    if new_tag_list:
        DocumentTag.objects.bulk_create(new_tag_list)

# Refactoring batch_add_document_tag
@transaction.atomic
def batch_add_document_tag(relation_map):
    existing_relations = set(QuerySet(...).filter(...).values_list('...'))
    new_entries = []

    for doc_id in relation_map:
        for tag_id in relation_map[doc_id]:
            if (doc_id, tag_id) not in existing_relations:
                entry = {...}
                new_entries.append(entry)
   
    if new_entries:
       _queryset.objects.bulk_create(..., new_entries)

These changes make the code cleaner, more maintainable, and safer from potential runtime errors when interacting with databases or external services.

Expand Down
2 changes: 1 addition & 1 deletion apps/locales/en_US/LC_MESSAGES/django.po
Original file line number Diff line number Diff line change
Expand Up @@ -8810,7 +8810,7 @@ msgstr ""
msgid "Audio file recognition - Tongyi Qwen"
msgstr ""

msgid "Audio file recognition - Fun-ASR/Paraformer/SenseVoice"
msgid "Real-time speech recognition - Fun-ASR/Paraformer"
msgstr ""

msgid "Qwen-Omni"
Expand Down
4 changes: 2 additions & 2 deletions apps/locales/zh_CN/LC_MESSAGES/django.po
Original file line number Diff line number Diff line change
Expand Up @@ -8936,8 +8936,8 @@ msgstr "阻止不安全的重定向到内部主机"
msgid "Audio file recognition - Tongyi Qwen"
msgstr "录音文件识别-通义千问"

msgid "Audio file recognition - Fun-ASR/Paraformer/SenseVoice"
msgstr "录音文件识别-Fun-ASR/Paraformer/SenseVoice"
msgid "Real-time speech recognition - Fun-ASR/Paraformer"
msgstr "实时语音识别-Fun-ASR/Paraformer"

msgid "Qwen-Omni"
msgstr "多模态"
4 changes: 2 additions & 2 deletions apps/locales/zh_Hant/LC_MESSAGES/django.po
Original file line number Diff line number Diff line change
Expand Up @@ -8936,8 +8936,8 @@ msgstr "阻止不安全的重定向到內部主機"
msgid "Audio file recognition - Tongyi Qwen"
msgstr "錄音文件識別-通義千問"

msgid "Audio file recognition - Fun-ASR/Paraformer/SenseVoice"
msgstr "錄音文件識別-Fun-ASR/Paraformer/SenseVoice"
msgid "Real-time speech recognition - Fun-ASR/Paraformer"
msgstr "實時語音識別-Fun-ASR/Paraformer"

msgid "Qwen-Omni"
msgstr "多模態"
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@


class AliyunBaiLianDefaultSTTModelCredential(BaseForm, BaseModelCredential):
type = forms.Radio(_("Type"), required=True, text_field='label', default_value='qwen', provider='', method='',
type = forms.SingleSelect(_("API"), required=True, text_field='label', default_value='qwen', provider='', method='',
value_field='value', option_list=[
{'label': _('Audio file recognition - Tongyi Qwen'),
'value': 'qwen'},
{'label': _('Qwen-Omni'),
'value': 'omni'},
{'label': _('Audio file recognition - Fun-ASR/Paraformer/SenseVoice'),
{'label': _('Real-time speech recognition - Fun-ASR/Paraformer'),
'value': 'other'}
])
api_url = forms.TextInputField(_('API URL'), required=True, relation_show_field_dict={'type': ['qwen', 'omni']})
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code looks mostly clean, but here are some suggestions:

  1. The option_list should be indented under the type field definition.

Here's the corrected version with indentation added:

class AliyunBaiLianDefaultSTTModelCredential(BaseForm, BaseModelCredential):
    type = forms.SingleSelect(
        _("API"), 
        required=True,
        text_field='label',
        default_value='qwen', 
        provider='', method='',
        value_field='value',
        option_list=[
            {'label': _('Audio file recognition - Tongyi Qwen'),
                'value': 'qwen'},
            {'label': _('Qwen-Omni'),
                'value': 'omni'},
            {'label': _('Real-time speech recognition - Fun-ASR/Paraformer'),
                'value': 'other'}
        ]
    )
    api_url = forms.TextInputField(
        _('API URL'), 
        required=True,
        relation_show_field_dict={'type': ['qwen', 'omni']}
    )
  2. Consider renaming the api_url to something more descriptive like speech_recognition_endpoint, as it suggests that this field is related to speech recognition rather than simply an API URL.

  3. Although not strictly necessary from a technical perspective, ensuring consistent coding style (e.g., using consistent spacing around operators and closing braces) can improve readability.

These changes make the code more readable and maintainable while maintaining its functional correctness.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def speech_to_text(self, audio_file):
"format": "mp3",
},
},
{"type": "text", "text": self.params.get('CueWord')},
{"type": "text", "text": self.params.get('CueWord') or '这段音频在说什么'},
],
},
],
Expand All @@ -77,7 +77,7 @@ def speech_to_text(self, audio_file):
# stream 必须设置为 True,否则会报错
stream=True,
stream_options={"include_usage": True},
extra_body=self.params
extra_body = {'enable_thinking': False, **self.params},
)
result = []
for chunk in completion:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code checks for certain irregularities that need addressing:

Irregularity/Issue

  1. Missing self Parameter: The extra_body parameter is missing the self. prefix before it. This will raise an AttributeError because Python needs to access class attributes using the instance self.

  •            extra_body=self.params
    
  •            extra_body=self.extra_body
    

### Optimization Suggestion (if applicable)

There doesn't appear to be any immediate optimization opportunities within this function, as it primarily involves processing audio and sending requests.

### Additional Considerations

- **Exception Handling**: Ensure there's proper error handling around HTTP requests if they might fail unexpectedly.
  
  ```python
  except Exception as e:
      print(f"An error occurred in speech_to_text: {e}")
  ```

- **Logging**: Implement logging statements to track the flow of execution and errors more effectively.

    import logging
    
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    
    logger.addHandler(handler)
    
    try:
        completion_stream = openai.Completion.create(
            ...
            additional_body={'enable_thinking': False, **self.params}
        )
    except Exception as e:
        logger.error(f"Error generating text from AI: {e}")
        return None

These changes should help ensure that the code functions correctly without syntax errors and improves robustness through some basic error handling practices.

Expand Down
Loading