|
44 | 44 | delete_problems_and_mappings |
45 | 45 | from knowledge.task.embedding import embedding_by_document, delete_embedding_by_document_list, \ |
46 | 46 | delete_embedding_by_document |
| 47 | +from knowledge.task.generate import generate_related_by_document_id |
47 | 48 | from knowledge.task.sync import sync_web_document |
48 | 49 | from maxkb.const import PROJECT_DIR |
49 | 50 |
|
@@ -109,17 +110,17 @@ class DocumentEditInstanceSerializer(serializers.Serializer): |
109 | 110 |
|
110 | 111 | @staticmethod |
111 | 112 | def get_meta_valid_map(): |
112 | | - dataset_meta_valid_map = { |
| 113 | + knowledge_meta_valid_map = { |
113 | 114 | KnowledgeType.BASE: MetaSerializer.BaseMeta, |
114 | 115 | KnowledgeType.WEB: MetaSerializer.WebMeta |
115 | 116 | } |
116 | | - return dataset_meta_valid_map |
| 117 | + return knowledge_meta_valid_map |
117 | 118 |
|
118 | 119 | def is_valid(self, *, document: Document = None): |
119 | 120 | super().is_valid(raise_exception=True) |
120 | 121 | if 'meta' in self.data and self.data.get('meta') is not None: |
121 | | - dataset_meta_valid_map = self.get_meta_valid_map() |
122 | | - valid_class = dataset_meta_valid_map.get(document.type) |
| 122 | + knowledge_meta_valid_map = self.get_meta_valid_map() |
| 123 | + valid_class = knowledge_meta_valid_map.get(document.type) |
123 | 124 | valid_class(data=self.data.get('meta')).is_valid(raise_exception=True) |
124 | 125 |
|
125 | 126 |
|
@@ -154,6 +155,18 @@ class DocumentRefreshSerializer(serializers.Serializer): |
154 | 155 | state_list = serializers.ListField(required=True, label=_('state list')) |
155 | 156 |
|
156 | 157 |
|
class DocumentBatchRefreshSerializer(serializers.Serializer):
    """Request payload for batch-refreshing documents.

    id_list: document primary keys to refresh; state_list: task states to re-queue.
    """
    # Validate entries as UUIDs, consistent with BatchEditHitHandlingSerializer.id_list.
    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
    state_list = serializers.ListField(required=True, label=_('state list'))
| 161 | + |
| 162 | + |
class DocumentBatchGenerateRelatedSerializer(serializers.Serializer):
    """Request payload for batch "generate related problems" over documents.

    document_id_list: target documents; model_id: LLM used for generation;
    prompt: generation prompt; state_list: task states eligible for re-processing.
    """
    # Entries are document primary keys (used in id__in filters) — validate as UUIDs,
    # consistent with BatchEditHitHandlingSerializer.id_list.
    document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
                                             label=_('document id list'))
    model_id = serializers.UUIDField(required=True, label=_('model id'))
    prompt = serializers.CharField(required=True, label=_('prompt'))
    state_list = serializers.ListField(required=True, label=_('state list'))
157 | 170 | class BatchEditHitHandlingSerializer(serializers.Serializer): |
158 | 171 | id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list')) |
159 | 172 | hit_handling_method = serializers.CharField(required=True, label=_('hit handling method')) |
@@ -521,26 +534,28 @@ def save_web(self, instance: Dict, with_valid=True): |
521 | 534 | if with_valid: |
522 | 535 | DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True) |
523 | 536 | self.is_valid(raise_exception=True) |
524 | | - dataset_id = self.data.get('dataset_id') |
| 537 | + knowledge_id = self.data.get('knowledge_id') |
525 | 538 | source_url_list = instance.get('source_url_list') |
526 | 539 | selector = instance.get('selector') |
527 | | - sync_web_document.delay(dataset_id, source_url_list, selector) |
| 540 | + sync_web_document.delay(knowledge_id, source_url_list, selector) |
528 | 541 |
|
529 | 542 | def save_qa(self, instance: Dict, with_valid=True): |
530 | 543 | if with_valid: |
531 | 544 | DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True) |
532 | 545 | self.is_valid(raise_exception=True) |
533 | 546 | file_list = instance.get('file_list') |
534 | 547 | document_list = flat_map([self.parse_qa_file(file) for file in file_list]) |
535 | | - return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list) |
| 548 | + return DocumentSerializers.Batch(data={'knowledge_id': self.data.get('knowledge_id')}).batch_save( |
| 549 | + document_list) |
536 | 550 |
|
537 | 551 | def save_table(self, instance: Dict, with_valid=True): |
538 | 552 | if with_valid: |
539 | 553 | DocumentInstanceTableSerializer(data=instance).is_valid(raise_exception=True) |
540 | 554 | self.is_valid(raise_exception=True) |
541 | 555 | file_list = instance.get('file_list') |
542 | 556 | document_list = flat_map([self.parse_table_file(file) for file in file_list]) |
543 | | - return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list) |
| 557 | + return DocumentSerializers.Batch(data={'knowledge_id': self.data.get('knowledge_id')}).batch_save( |
| 558 | + document_list) |
544 | 559 |
|
545 | 560 | def parse_qa_file(self, file): |
546 | 561 | get_buffer = FileBufferHandle().get_buffer |
@@ -788,6 +803,42 @@ def batch_refresh(self, instance: Dict, with_valid=True): |
788 | 803 | except AlreadyQueued as e: |
789 | 804 | pass |
790 | 805 |
|
| 806 | + class BatchGenerateRelated(serializers.Serializer): |
| 807 | + workspace_id = serializers.CharField(required=True, label=_('workspace id')) |
| 808 | + knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id')) |
| 809 | + |
| 810 | + def batch_generate_related(self, instance: Dict, with_valid=True): |
| 811 | + if with_valid: |
| 812 | + self.is_valid(raise_exception=True) |
| 813 | + document_id_list = instance.get("document_id_list") |
| 814 | + model_id = instance.get("model_id") |
| 815 | + prompt = instance.get("prompt") |
| 816 | + state_list = instance.get('state_list') |
| 817 | + ListenerManagement.update_status( |
| 818 | + QuerySet(Document).filter(id__in=document_id_list), |
| 819 | + TaskType.GENERATE_PROBLEM, |
| 820 | + State.PENDING |
| 821 | + ) |
| 822 | + ListenerManagement.update_status( |
| 823 | + QuerySet(Paragraph).annotate( |
| 824 | + reversed_status=Reverse('status'), |
| 825 | + task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value, |
| 826 | + 1), |
| 827 | + ).filter( |
| 828 | + task_type_status__in=state_list, document_id__in=document_id_list |
| 829 | + ) |
| 830 | + .values('id'), |
| 831 | + TaskType.GENERATE_PROBLEM, |
| 832 | + State.PENDING |
| 833 | + ) |
| 834 | + ListenerManagement.get_aggregation_document_status_by_query_set( |
| 835 | + QuerySet(Document).filter(id__in=document_id_list))() |
| 836 | + try: |
| 837 | + for document_id in document_id_list: |
| 838 | + generate_related_by_document_id.delay(document_id, model_id, prompt, state_list) |
| 839 | + except AlreadyQueued as e: |
| 840 | + pass |
| 841 | + |
791 | 842 |
|
792 | 843 | class FileBufferHandle: |
793 | 844 | buffer = None |
|
0 commit comments