
Commit e702af8

feat: enhance Document API with workspace ID support for get, put, and delete operations
1 parent 3e9069a commit e702af8

7 files changed: +263 -16 lines changed


apps/common/handle/impl/table/csv_parse_table_handle.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 max_kb = logging.getLogger("max_kb")


-class CsvSplitHandle(BaseParseTableHandle):
+class CsvParseTableHandle(BaseParseTableHandle):
     def support(self, file, get_buffer):
         file_name: str = file.name.lower()
         if file_name.endswith(".csv"):

apps/common/handle/impl/table/xls_parse_table_handle.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 max_kb = logging.getLogger("max_kb")


-class XlsSplitHandle(BaseParseTableHandle):
+class XlsParseTableHandle(BaseParseTableHandle):
     def support(self, file, get_buffer):
         file_name: str = file.name.lower()
         buffer = get_buffer(file)

apps/common/handle/impl/table/xlsx_parse_table_handle.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 max_kb = logging.getLogger("max_kb")


-class XlsxSplitHandle(BaseParseTableHandle):
+class XlsxParseTableHandle(BaseParseTableHandle):
     def support(self, file, get_buffer):
         file_name: str = file.name.lower()
         if file_name.endswith('.xlsx'):

apps/knowledge/api/document.py

Lines changed: 43 additions & 1 deletion
@@ -4,7 +4,7 @@
 from common.mixins.api_mixin import APIMixin
 from common.result import DefaultResultSerializer
 from knowledge.serializers.common import BatchSerializer
-from knowledge.serializers.document import DocumentInstanceSerializer
+from knowledge.serializers.document import DocumentInstanceSerializer, DocumentWebInstanceSerializer


 class DocumentSplitAPI(APIMixin):
@@ -176,3 +176,45 @@ def get_request():

 class DocumentDeleteAPI(DocumentReadAPI):
     pass
+
+
+class TableDocumentCreateAPI(APIMixin):
+    @staticmethod
+    def get_parameters():
+        return [
+            OpenApiParameter(
+                name="workspace_id",
+                description="workspace id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+            OpenApiParameter(
+                name="knowledge_id",
+                description="knowledge base id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+            OpenApiParameter(
+                name="file",
+                description="file",
+                type=OpenApiTypes.BINARY,
+                location='query',
+                required=False,
+            ),
+        ]
+
+    @staticmethod
+    def get_response():
+        return DefaultResultSerializer
+
+
+class QaDocumentCreateAPI(TableDocumentCreateAPI):
+    pass
+
+
+class WebDocumentCreateAPI(APIMixin):
+    @staticmethod
+    def get_request():
+        return DocumentWebInstanceSerializer
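
The view layer that consumes these API definitions is not part of the hunks shown above. As a rough sketch of how TableDocumentCreateAPI could be attached to the TableDocumentView registered in apps/knowledge/urls.py (decorator usage follows drf-spectacular; the method body and summary text are assumptions, not code from this commit):

# Illustrative sketch only -- the real view implementation is not included in this diff.
from drf_spectacular.utils import extend_schema
from rest_framework.views import APIView

from knowledge.api.document import TableDocumentCreateAPI


class TableDocumentView(APIView):
    @extend_schema(
        methods=['POST'],
        summary='Create documents from table files',          # assumed wording
        parameters=TableDocumentCreateAPI.get_parameters(),   # workspace_id, knowledge_id, file
        responses=TableDocumentCreateAPI.get_response(),
    )
    def post(self, request, workspace_id: str, knowledge_id: str):
        # A real handler would likely delegate to a DocumentSerializers save_table(...) call.
        ...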

apps/knowledge/serializers/document.py

Lines changed: 140 additions & 2 deletions
@@ -1,11 +1,13 @@
 import logging
 import os
+import re
 import traceback
 from functools import reduce
 from typing import Dict, List

 import uuid_utils.compat as uuid
 from celery_once import AlreadyQueued
+from django.core import validators
 from django.db import transaction, models
 from django.db.models import QuerySet, Model
 from django.db.models.functions import Substr, Reverse
@@ -16,6 +18,13 @@
 from common.event import ListenerManagement
 from common.event.common import work_thread_pool
 from common.exception.app_exception import AppApiException
+from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
+from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
+from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
+from common.handle.impl.qa.zip_parse_qa_handle import ZipParseQAHandle
+from common.handle.impl.table.csv_parse_table_handle import CsvParseTableHandle
+from common.handle.impl.table.xls_parse_table_handle import XlsParseTableHandle
+from common.handle.impl.table.xlsx_parse_table_handle import XlsxParseTableHandle
 from common.handle.impl.text.csv_split_handle import CsvSplitHandle
 from common.handle.impl.text.doc_split_handle import DocSplitHandle
 from common.handle.impl.text.html_split_handle import HTMLSplitHandle
@@ -26,14 +35,16 @@
 from common.handle.impl.text.zip_split_handle import ZipSplitHandle
 from common.utils.common import post, get_file_content, bulk_create_in_batches
 from common.utils.fork import Fork
-from common.utils.split_model import get_split_model
+from common.utils.split_model import get_split_model, flat_map
 from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
     TaskType, File
-from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, get_embedding_model_id_by_knowledge_id
+from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, \
+    get_embedding_model_id_by_knowledge_id, MetaSerializer
 from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \
     delete_problems_and_mappings
 from knowledge.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
     delete_embedding_by_document
+from knowledge.task.sync import sync_web_document
 from maxkb.const import PROJECT_DIR

 default_split_handle = TextSplitHandle()
@@ -48,6 +59,9 @@
     default_split_handle
 ]

+parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()]
+parse_table_handle_list = [CsvParseTableHandle(), XlsParseTableHandle(), XlsxParseTableHandle()]
+

 class BatchCancelInstanceSerializer(serializers.Serializer):
     id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=('id list'))
@@ -67,6 +81,36 @@ class DocumentInstanceSerializer(serializers.Serializer):
     paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)


+class DocumentEditInstanceSerializer(serializers.Serializer):
+    meta = serializers.DictField(required=False)
+    name = serializers.CharField(required=False, max_length=128, min_length=1, label=_('document name'))
+    hit_handling_method = serializers.CharField(required=False, validators=[
+        validators.RegexValidator(regex=re.compile("^optimization|directly_return$"),
+                                  message=_('The type only supports optimization|directly_return'),
+                                  code=500)
+    ], label=_('hit handling method'))
+
+    directly_return_similarity = serializers.FloatField(required=False, max_value=2, min_value=0,
+                                                        label=_('directly return similarity'))
+
+    is_active = serializers.BooleanField(required=False, label=_('document is active'))
+
+    @staticmethod
+    def get_meta_valid_map():
+        dataset_meta_valid_map = {
+            KnowledgeType.BASE: MetaSerializer.BaseMeta,
+            KnowledgeType.WEB: MetaSerializer.WebMeta
+        }
+        return dataset_meta_valid_map
+
+    def is_valid(self, *, document: Document = None):
+        super().is_valid(raise_exception=True)
+        if 'meta' in self.data and self.data.get('meta') is not None:
+            dataset_meta_valid_map = self.get_meta_valid_map()
+            valid_class = dataset_meta_valid_map.get(document.type)
+            valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
+
+
 class DocumentSplitRequest(serializers.Serializer):
     file = serializers.ListField(required=True, label=_('file list'))
     limit = serializers.IntegerField(required=False, label=_('limit'))
@@ -78,6 +122,22 @@ class DocumentSplitRequest(serializers.Serializer):
     with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))


+class DocumentWebInstanceSerializer(serializers.Serializer):
+    source_url_list = serializers.ListField(required=True, label=_('document url list'),
+                                            child=serializers.CharField(required=True, label=_('document url list')))
+    selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('selector'))
+
+
+class DocumentInstanceQASerializer(serializers.Serializer):
+    file_list = serializers.ListSerializer(required=True, label=_('file list'),
+                                           child=serializers.FileField(required=True, label=_('file')))
+
+
+class DocumentInstanceTableSerializer(serializers.Serializer):
+    file_list = serializers.ListSerializer(required=True, label=_('file list'),
+                                           child=serializers.FileField(required=True, label=_('file')))
+
+
 class DocumentSerializers(serializers.Serializer):
     class Query(serializers.Serializer):
         # knowledge base id
@@ -226,6 +286,7 @@ def sync(self, with_valid=True, with_embedding=True):
         return True

     class Operate(serializers.Serializer):
+        workspace_id = serializers.CharField(required=True, label=_('workspace id'))
         document_id = serializers.UUIDField(required=True, label=_('document id'))
         knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

@@ -246,6 +307,31 @@ def one(self, with_valid=False):
                 }, select_string=get_file_content(
                     os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')), with_search_one=True)

+        def edit(self, instance: Dict, with_valid=False):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            _document = QuerySet(Document).get(id=self.data.get("document_id"))
+            if with_valid:
+                DocumentEditInstanceSerializer(data=instance).is_valid(document=_document)
+            update_keys = ['name', 'is_active', 'hit_handling_method', 'directly_return_similarity', 'meta']
+            for update_key in update_keys:
+                if update_key in instance and instance.get(update_key) is not None:
+                    _document.__setattr__(update_key, instance.get(update_key))
+            _document.save()
+            return self.one()
+
+        @transaction.atomic
+        def delete(self):
+            document_id = self.data.get("document_id")
+            QuerySet(model=Document).filter(id=document_id).delete()
+            # delete paragraphs
+            QuerySet(model=Paragraph).filter(document_id=document_id).delete()
+            # delete problems
+            delete_problems_and_mappings([document_id])
+            # delete from the vector store
+            delete_embedding_by_document(document_id)
+            return True
+
         def refresh(self, state_list=None, with_valid=True):
             if state_list is None:
                 state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
@@ -369,6 +455,58 @@ def get_document_paragraph_model(knowledge_id, instance: Dict):
                 instance.get('paragraphs') if 'paragraphs' in instance else []
             )

+        def save_web(self, instance: Dict, with_valid=True):
+            if with_valid:
+                DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
+                self.is_valid(raise_exception=True)
+            dataset_id = self.data.get('dataset_id')
+            source_url_list = instance.get('source_url_list')
+            selector = instance.get('selector')
+            sync_web_document.delay(dataset_id, source_url_list, selector)
+
+        def save_qa(self, instance: Dict, with_valid=True):
+            if with_valid:
+                DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True)
+                self.is_valid(raise_exception=True)
+            file_list = instance.get('file_list')
+            document_list = flat_map([self.parse_qa_file(file) for file in file_list])
+            return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)
+
+        def save_table(self, instance: Dict, with_valid=True):
+            if with_valid:
+                DocumentInstanceTableSerializer(data=instance).is_valid(raise_exception=True)
+                self.is_valid(raise_exception=True)
+            file_list = instance.get('file_list')
+            document_list = flat_map([self.parse_table_file(file) for file in file_list])
+            return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)
+
+        def parse_qa_file(self, file):
+            get_buffer = FileBufferHandle().get_buffer
+            for parse_qa_handle in parse_qa_handle_list:
+                if parse_qa_handle.support(file, get_buffer):
+                    return parse_qa_handle.handle(file, get_buffer, self.save_image)
+            raise AppApiException(500, _('Unsupported file format'))
+
+        def parse_table_file(self, file):
+            get_buffer = FileBufferHandle().get_buffer
+            for parse_table_handle in parse_table_handle_list:
+                if parse_table_handle.support(file, get_buffer):
+                    return parse_table_handle.handle(file, get_buffer, self.save_image)
+            raise AppApiException(500, _('Unsupported file format'))
+
+        def save_image(self, image_list):
+            if image_list is not None and len(image_list) > 0:
+                exist_image_list = [str(i.get('id')) for i in
+                                    QuerySet(File).filter(id__in=[i.id for i in image_list]).values('id')]
+                save_image_list = [image for image in image_list if not exist_image_list.__contains__(str(image.id))]
+                save_image_list = list({img.id: img for img in save_image_list}.values())
+                # save images
+                for file in save_image_list:
+                    file_bytes = file.meta.pop('content')
+                    file.workspace_id = self.data.get('workspace_id')
+                    file.meta['knowledge_id'] = self.data.get('knowledge_id')
+                    file.save(file_bytes)
+
     class Split(serializers.Serializer):
         workspace_id = serializers.CharField(required=True, label=_('workspace id'))
         knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
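
The Operate serializer now carries workspace_id alongside knowledge_id and document_id, which is what the commit message refers to for the get, put, and delete operations. A minimal usage sketch follows; the identifier values and request body are placeholders, not code from this commit:

# Hedged sketch of the new Operate flow; values below are illustrative placeholders.
from knowledge.serializers.document import DocumentSerializers

workspace_id, knowledge_id, document_id = '<workspace-id>', '<knowledge-uuid>', '<document-uuid>'
request_body = {'name': 'renamed document', 'is_active': True}

operate = DocumentSerializers.Operate(data={
    'workspace_id': workspace_id,   # newly required field in this commit
    'knowledge_id': knowledge_id,
    'document_id': document_id,
})

detail = operate.one(with_valid=True)                   # GET a single document
updated = operate.edit(request_body, with_valid=True)   # PUT, validated via DocumentEditInstanceSerializer
operate.delete()                                        # DELETE document, paragraphs, problems, and embeddings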

apps/knowledge/urls.py

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch', views.DocumentView.Batch.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/web', views.WebDocumentView.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/qa', views.QaDocumentView.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/table', views.TableDocumentView.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>', views.DocumentView.Operate.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()),
 ]
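
With the three new routes registered, clients can POST web, QA, and table imports per knowledge base. A hedged client-side example for the web route; the /api prefix, host, and auth header are deployment assumptions, not defined in this commit:

# Hypothetical client call; adjust base URL and authentication to your deployment.
import requests

BASE = 'http://localhost:8080/api'             # assumed prefix
headers = {'Authorization': 'Bearer <token>'}  # assumed auth scheme
workspace_id, knowledge_id = '<workspace-id>', '<knowledge-uuid>'

resp = requests.post(
    f'{BASE}/workspace/{workspace_id}/knowledge/{knowledge_id}/document/web',
    json={
        'source_url_list': ['https://example.com/docs/page'],  # DocumentWebInstanceSerializer field
        'selector': '.article-body',                           # optional CSS selector
    },
    headers=headers,
)
print(resp.status_code, resp.json())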
