Skip to content

Commit 44c0a7b

Browse files
committed
feat: Knowledge base generation problem
1 parent 596b137 commit 44c0a7b

File tree

11 files changed

+139
-10
lines changed

11 files changed

+139
-10
lines changed

apps/dataset/serializers/common_serializers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,3 +222,26 @@ def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List):
222222
if len(dataset_list) == 0:
223223
raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
224224
return str(dataset_list[0].embedding_mode_id)
225+
226+
227+
class GenerateRelatedSerializer(ApiMixin, serializers.Serializer):
    """Validate the request body of a "generate related" call.

    Fields:
        model_id: UUID of the model to use (required).
        prompt: prompt text (required).
        state_list: optional list of state strings.
    """
    model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Model id')))
    # ``prompt`` is free text, so it takes the character-field message set instead of the
    # UUID one the original used.
    # NOTE(review): assumes ErrMessage exposes ``char`` alongside ``uuid``/``list`` - confirm.
    prompt = serializers.CharField(required=True, error_messages=ErrMessage.char(_('Prompt word')))
    # The label is wrapped in ``_()`` so it is translated like the other field labels here.
    state_list = serializers.ListField(required=False, child=serializers.CharField(required=True),
                                       error_messages=ErrMessage.list(_('state list')))

    @staticmethod
    def get_request_body_api():
        """Return the OpenAPI object schema describing the request body.

        ``model_id`` and ``prompt`` are listed as required to mirror the serializer fields;
        ``state_list`` stays optional.
        """
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['model_id', 'prompt'],
            properties={
                'model_id': openapi.Schema(type=openapi.TYPE_STRING,
                                           title=_('Model id'),
                                           description=_('Model id')),
                'prompt': openapi.Schema(type=openapi.TYPE_STRING, title=_('Prompt word'),
                                         description=_("Prompt word")),
                'state_list': openapi.Schema(type=openapi.TYPE_ARRAY,
                                             items=openapi.Schema(type=openapi.TYPE_STRING),
                                             title=_('state list'))
            }
        )

apps/dataset/serializers/dataset_serializers.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from django.core import validators
2424
from django.db import transaction, models
2525
from django.db.models import QuerySet
26+
from django.db.models.functions import Reverse, Substr
2627
from django.http import HttpResponse
2728
from drf_yasg import openapi
2829
from rest_framework import serializers
@@ -42,9 +43,10 @@
4243
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, TaskType, \
4344
State, File, Image
4445
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
45-
get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir
46+
get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir, \
47+
GenerateRelatedSerializer
4648
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
47-
from dataset.task import sync_web_dataset, sync_replace_web_dataset
49+
from dataset.task import sync_web_dataset, sync_replace_web_dataset, generate_related_by_dataset_id
4850
from embedding.models import SearchMode
4951
from embedding.task import embedding_by_dataset, delete_embedding_by_dataset
5052
from setting.models import AuthOperate, Model
@@ -814,6 +816,31 @@ def re_embedding(self, with_valid=True):
814816
except AlreadyQueued as e:
815817
raise AppApiException(500, _('Failed to send the vectorization task, please try again later!'))
816818

819+
def generate_related(self, instance: Dict, with_valid=True):
    """Mark a dataset's documents and paragraphs as pending and queue problem generation.

    :param instance: request body; ``model_id``, ``prompt`` and the optional ``state_list``
                     are read from it.
    :param with_valid: when true, validate both this serializer and ``instance``
                       (via ``GenerateRelatedSerializer``) before doing anything.
    :raises AppApiException: with code 500 when the task cannot be queued because an
                             identical one is already queued (``AlreadyQueued``).
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        GenerateRelatedSerializer(data=instance).is_valid(raise_exception=True)
    dataset_id = self.data.get('id')
    model_id = instance.get("model_id")
    prompt = instance.get("prompt")
    # ``state_list`` is optional in GenerateRelatedSerializer, so this may be None.
    state_list = instance.get('state_list')
    ListenerManagement.update_status(QuerySet(Document).filter(dataset_id=dataset_id),
                                     TaskType.GENERATE_PROBLEM,
                                     State.PENDING)
    if state_list is None:
        # An ``__in`` lookup cannot take None, so with no state restriction every paragraph
        # of the dataset is selected.
        # NOTE(review): assumes the task treats a missing state_list as "all states" - confirm.
        paragraph_query_set = QuerySet(Paragraph).filter(dataset_id=dataset_id)
    else:
        # Reverse the status string and take the single character at the position given by
        # TaskType.GENERATE_PROBLEM.value, then keep paragraphs whose character is requested.
        paragraph_query_set = QuerySet(Paragraph).annotate(
            reversed_status=Reverse('status'),
            task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
                                    1),
        ).filter(task_type_status__in=state_list, dataset_id=dataset_id)
    ListenerManagement.update_status(paragraph_query_set.values('id'),
                                     TaskType.GENERATE_PROBLEM,
                                     State.PENDING)
    ListenerManagement.get_aggregation_document_status_by_dataset_id(dataset_id)()
    try:
        generate_related_by_dataset_id.delay(dataset_id, model_id, prompt, state_list)
    except AlreadyQueued as e:
        # NOTE(review): the message mentions "vectorization" although this queues problem
        # generation; it is kept as-is because a new msgid would have no translation yet.
        raise AppApiException(500, _('Failed to send the vectorization task, please try again later!')) from e
843+
817844
def list_application(self, with_valid=True):
818845
if with_valid:
819846
self.is_valid(raise_exception=True)

apps/dataset/task/generate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,17 @@ def is_the_task_interrupted():
6464
return is_the_task_interrupted
6565

6666

67+
@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']},
                 name='celery:generate_related_by_dataset')
def generate_related_by_dataset_id(dataset_id, model_id, prompt, state_list=None):
    """Queue one ``generate_related_by_document_id`` task per document of a dataset.

    :param dataset_id: id of the dataset whose documents are processed.
    :param model_id: forwarded unchanged to each per-document task.
    :param prompt: forwarded unchanged to each per-document task.
    :param state_list: forwarded unchanged to each per-document task; may be None.
    """
    import logging

    logger = logging.getLogger(__name__)
    document_list = QuerySet(Document).filter(dataset_id=dataset_id)
    for document in document_list:
        try:
            generate_related_by_document_id.delay(document.id, model_id, prompt, state_list)
        except Exception:
            # Best effort: one document that cannot be queued must not stop the rest,
            # but the failure is recorded instead of being silently discarded.
            logger.warning('Failed to queue related-problem generation for document %s',
                           document.id, exc_info=True)
76+
77+
6778
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
6879
name='celery:generate_related_by_document')
6980
def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):

apps/dataset/urls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
path('dataset/<str:dataset_id>/export', views.Dataset.Export.as_view(), name="export"),
1212
path('dataset/<str:dataset_id>/export_zip', views.Dataset.ExportZip.as_view(), name="export_zip"),
1313
path('dataset/<str:dataset_id>/re_embedding', views.Dataset.Embedding.as_view(), name="dataset_key"),
14+
path('dataset/<str:dataset_id>/generate_related', views.Dataset.GenerateRelated.as_view(),
15+
name="dataset_generate_related"),
1416
path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
1517
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
1618
path('dataset/<str:dataset_id>/sync_web', views.Dataset.SyncWeb.as_view()),

apps/dataset/views/dataset.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from common.response import result
2222
from common.response.result import get_page_request_params, get_page_api_response, get_api_response
2323
from common.swagger_api.common_api import CommonApi
24+
from dataset.serializers.common_serializers import GenerateRelatedSerializer
2425
from dataset.serializers.dataset_serializers import DataSetSerializers
2526
from dataset.views.common import get_dataset_operation_object
2627
from setting.serializers.provider_serializers import ModelSerializer
@@ -173,6 +174,23 @@ def put(self, request: Request, dataset_id: str):
173174
return result.success(
174175
DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).re_embedding())
175176

177+
class GenerateRelated(APIView):
    """PUT endpoint that triggers ``generate_related`` for one dataset."""
    authentication_classes = [TokenAuth]

    @action(methods=['PUT'], detail=False)
    @swagger_auto_schema(
        operation_summary=_('Generate related'),
        operation_id=_('Generate related'),
        manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
        request_body=GenerateRelatedSerializer.get_request_body_api(),
        tags=[_('Knowledge Base')],
    )
    @log(
        menu='document',
        operate="Generate related documents",
        get_operation_object=lambda r, keywords: get_dataset_operation_object(keywords.get('dataset_id')),
    )
    def put(self, request: Request, dataset_id: str):
        """Run ``generate_related`` with the request body and wrap the outcome in a success result."""
        # The operate serializer is built from the URL's dataset id and the requesting user's id.
        operate_serializer = DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id})
        return result.success(operate_serializer.generate_related(request.data))
193+
176194
class Export(APIView):
177195
authentication_classes = [TokenAuth]
178196

apps/locales/en_US/LC_MESSAGES/django.po

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7487,4 +7487,7 @@ msgid "Field: {name} Type: {_type} Value: {value} Unsupported types"
74877487
msgstr ""
74887488

74897489
msgid "Field: {name} No value set"
7490+
msgstr ""
7491+
7492+
msgid "Generate related"
74907493
msgstr ""

apps/locales/zh_CN/LC_MESSAGES/django.po

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7650,4 +7650,7 @@ msgid "Field: {name} Type: {_type} Value: {value} Unsupported types"
76507650
msgstr "字段: {name} 类型: {_type} 值: {value} 不支持的类型"
76517651

76527652
msgid "Field: {name} No value set"
7653-
msgstr "字段: {name} 未设置值"
7653+
msgstr "字段: {name} 未设置值"
7654+
7655+
msgid "Generate related"
7656+
msgstr "生成问题"

apps/locales/zh_Hant/LC_MESSAGES/django.po

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7660,4 +7660,7 @@ msgid "Field: {name} Type: {_type} Value: {value} Unsupported types"
76607660
msgstr "欄位: {name} 類型: {_type} 值: {value} 不支持的類型"
76617661

76627662
msgid "Field: {name} No value set"
7663-
msgstr "欄位: {name} 未設定值"
7663+
msgstr "欄位: {name} 未設定值"
7664+
7665+
msgid "Generate related"
7666+
msgstr "生成問題"

ui/src/api/dataset.ts

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,20 @@ const importLarkDocument: (
277277
) => Promise<Result<Array<any>>> = (dataset_id, data, loading) => {
278278
return post(`${prefix}/lark/${dataset_id}/import`, data, null, loading)
279279
}
280+
/**
 * Generate related questions for a knowledge base.
 * @param dataset_id knowledge base id
 * @param data request body sent with the PUT request
 * @param loading optional loading flag forwarded to `put`
 * @returns the promise returned by `put`
 */
const generateRelated: (
  dataset_id: string,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (dataset_id, data, loading) =>
  put(`${prefix}/${dataset_id}/generate_related`, data, null, loading)
280294

281295
export default {
282296
getDataset,
@@ -297,5 +311,6 @@ export default {
297311
postLarkDataset,
298312
getLarkDocumentList,
299313
importLarkDocument,
300-
putLarkDataset
314+
putLarkDataset,
315+
generateRelated
301316
}

ui/src/components/generate-related-dialog/index.vue

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
/>
5252
</el-form-item>
5353
<el-form-item
54-
v-if="apiType === 'document'"
54+
v-if="['document', 'dataset'].includes(apiType)"
5555
:label="$t('components.selectParagraph.title')"
5656
prop="state"
5757
>
@@ -107,6 +107,7 @@ const stateMap = {
107107
error: ['0', '1', '3', '4', '5', 'n']
108108
}
109109
const FormRef = ref()
110+
const datasetId = ref<string>()
110111
const userId = user.userInfo?.id as string
111112
const form = ref(prompt.get(userId))
112113
const rules = reactive({
@@ -133,7 +134,8 @@ watch(dialogVisible, (bool) => {
133134
}
134135
})
135136
136-
const open = (ids: string[], type: string) => {
137+
const open = (ids: string[], type: string, _datasetId?: string) => {
138+
datasetId.value = _datasetId
137139
getModel()
138140
idList.value = ids
139141
apiType.value = type
@@ -169,6 +171,15 @@ const submitHandle = async (formEl: FormInstance) => {
169171
emit('refresh')
170172
dialogVisible.value = false
171173
})
174+
} else if (apiType.value === 'dataset') {
175+
const data = {
176+
...form.value,
177+
state_list: stateMap[state.value]
178+
}
179+
datasetApi.generateRelated(id ? id : datasetId.value, data, loading).then(() => {
180+
MsgSuccess(t('views.document.generateQuestion.successMessage'))
181+
dialogVisible.value = false
182+
})
172183
}
173184
}
174185
})
@@ -177,7 +188,7 @@ const submitHandle = async (formEl: FormInstance) => {
177188
function getModel() {
178189
loading.value = true
179190
datasetApi
180-
.getDatasetModel(id)
191+
.getDatasetModel(id ? id : datasetId.value)
181192
.then((res: any) => {
182193
modelOptions.value = groupBy(res?.data, 'provider')
183194
loading.value = false

0 commit comments

Comments
 (0)