Skip to content

Commit 54836d5

Browse files
committed
feat: add Embedding API for re-vectorization of knowledge objects
1 parent b6f5a83 commit 54836d5

File tree

4 files changed

+57
-1
lines changed

4 files changed

+57
-1
lines changed

apps/knowledge/api/knowledge.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,7 @@ class HitTestAPI(SyncWebAPI):
244244
@staticmethod
245245
def get_request():
246246
return HitTestSerializer
247+
248+
249+
class EmbeddingAPI(SyncWebAPI):
250+
pass

apps/knowledge/serializers/knowledge.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from knowledge.task.generate import generate_related_by_knowledge_id
3232
from knowledge.task.sync import sync_web_knowledge, sync_replace_web_knowledge
3333
from maxkb.conf import PROJECT_DIR
34+
from models_provider.models import Model
3435

3536

3637
class KnowledgeModelSerializer(serializers.ModelSerializer):
@@ -80,6 +81,7 @@ def is_valid(self, *, knowledge: Knowledge = None):
8081
valid_class = knowledge_meta_valid_map.get(knowledge.type)
8182
valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
8283

84+
8385
class HitTestSerializer(serializers.Serializer):
8486
query_text = serializers.CharField(required=True, label=_('query text'))
8587
top_number = serializers.IntegerField(required=True, max_value=10000, min_value=1, label=_("top number"))
@@ -89,6 +91,7 @@ class HitTestSerializer(serializers.Serializer):
8991
message=_('The type only supports embedding|keywords|blend'), code=500)
9092
])
9193

94+
9295
class KnowledgeSerializer(serializers.Serializer):
9396
class Query(serializers.Serializer):
9497
workspace_id = serializers.CharField(required=True)
@@ -157,6 +160,36 @@ class Operate(serializers.Serializer):
157160
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
158161
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
159162

163+
@transaction.atomic
164+
def embedding(self, with_valid=True):
165+
if with_valid:
166+
self.is_valid(raise_exception=True)
167+
knowledge_id = self.data.get('knowledge_id')
168+
knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
169+
embedding_model_id = knowledge.embedding_mode_id
170+
knowledge_user_id = knowledge.user_id
171+
embedding_model = QuerySet(Model).filter(id=embedding_model_id).first()
172+
if embedding_model is None:
173+
raise AppApiException(500, _('Model does not exist'))
174+
if embedding_model.permission_type == 'PRIVATE' and knowledge_user_id != embedding_model.user_id:
175+
raise AppApiException(500, _('No permission to use this model') + f"{embedding_model.name}")
176+
ListenerManagement.update_status(
177+
QuerySet(Document).filter(knowledge_id=self.data.get('knowledge_id')),
178+
TaskType.EMBEDDING,
179+
State.PENDING
180+
)
181+
ListenerManagement.update_status(
182+
QuerySet(Paragraph).filter(knowledge_id=self.data.get('knowledge_id')),
183+
TaskType.EMBEDDING,
184+
State.PENDING
185+
)
186+
ListenerManagement.get_aggregation_document_status_by_knowledge_id(self.data.get('knowledge_id'))()
187+
embedding_model_id = get_embedding_model_id_by_knowledge_id(self.data.get('knowledge_id'))
188+
try:
189+
embedding_by_knowledge.delay(knowledge_id, embedding_model_id)
190+
except AlreadyQueued as e:
191+
raise AppApiException(500, _('Failed to send the vectorization task, please try again later!'))
192+
160193
def generate_related(self, instance: Dict, with_valid=True):
161194
if with_valid:
162195
self.is_valid(raise_exception=True)

apps/knowledge/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
1111
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/sync', views.KnowledgeView.SyncWeb.as_view()),
1212
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/generate_related', views.KnowledgeView.GenerateRelated.as_view()),
13+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/embedding', views.KnowledgeView.Embedding.as_view()),
1314
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/hit_test', views.KnowledgeView.HitTest.as_view()),
1415
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()),
1516
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),

apps/knowledge/views/knowledge.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from common.constants.permission_constants import PermissionConstants
99
from common.result import result
1010
from knowledge.api.knowledge import KnowledgeBaseCreateAPI, KnowledgeWebCreateAPI, KnowledgeTreeReadAPI, \
11-
KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI, HitTestAPI
11+
KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI, HitTestAPI, EmbeddingAPI
1212
from knowledge.serializers.knowledge import KnowledgeSerializer
1313

1414

@@ -160,6 +160,24 @@ def put(self, request: Request, workspace_id: str, knowledge_id: str):
160160
}
161161
).hit_test())
162162

163+
class Embedding(APIView):
164+
authentication_classes = [TokenAuth]
165+
166+
@extend_schema(
167+
methods=['PUT'],
168+
summary=_('Re-vectorize'),
169+
description=_('Re-vectorize'),
170+
operation_id=_('Re-vectorize'),
171+
parameters=EmbeddingAPI.get_parameters(),
172+
responses=EmbeddingAPI.get_response(),
173+
tags=[_('Knowledge Base')]
174+
)
175+
@has_permissions(PermissionConstants.KNOWLEDGE_EDIT.get_workspace_permission())
176+
def put(self, request: Request, workspace_id: str, knowledge_id: str):
177+
return result.success(KnowledgeSerializer.Operate(
178+
data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id, 'user_id': request.user.id}
179+
).embedding())
180+
163181
class GenerateRelated(APIView):
164182
authentication_classes = [TokenAuth]
165183

0 commit comments

Comments
 (0)