|
18 | 18 |
|
19 | 19 | from common.config.embedding_config import ModelManage |
20 | 20 | from common.db.search import native_search |
21 | | -from common.db.sql_execute import update_execute |
| 21 | +from common.db.sql_execute import update_execute, sql_execute |
22 | 22 | from common.exception.app_exception import AppApiException |
23 | 23 | from common.mixins.api_mixin import ApiMixin |
24 | 24 | from common.util.field_message import ErrMessage |
25 | 25 | from common.util.file_util import get_file_content |
26 | 26 | from common.util.fork import Fork |
27 | | -from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet, File, Image |
| 27 | +from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet, File, Image, Document |
28 | 28 | from setting.models_provider import get_model |
29 | 29 | from smartdoc.conf import PROJECT_DIR |
30 | 30 | from django.utils.translation import gettext_lazy as _ |
@@ -224,6 +224,46 @@ def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List): |
224 | 224 | return str(dataset_list[0].embedding_mode_id) |
225 | 225 |
|
226 | 226 |
|
| 227 | + |
| 228 | +def create_dataset_index(dataset_id=None, document_id=None): |
| 229 | + if dataset_id is None and document_id is None: |
| 230 | + raise AppApiException(500, _('Dataset ID or Document ID must be provided')) |
| 231 | + |
| 232 | + if dataset_id is not None: |
| 233 | + k_id = dataset_id |
| 234 | + else: |
| 235 | + document = QuerySet(Document).filter(id=document_id).first() |
| 236 | + k_id = document.dataset_id |
| 237 | + |
| 238 | + sql = f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = 'embedding' AND indexname = 'embedding_hnsw_idx_{k_id}'" |
| 239 | + index = sql_execute(sql, []) |
| 240 | + if not index: |
| 241 | + sql = f"SELECT vector_dims(embedding) AS dims FROM embedding WHERE dataset_id = '{k_id}' LIMIT 1" |
| 242 | + result = sql_execute(sql, []) |
| 243 | + if len(result) == 0: |
| 244 | + return |
| 245 | + dims = result[0]['dims'] |
| 246 | + sql = f"""CREATE INDEX "embedding_hnsw_idx_{k_id}" ON embedding USING hnsw ((embedding::vector({dims})) vector_cosine_ops) WHERE dataset_id = '{k_id}'""" |
| 247 | + update_execute(sql, []) |
| 248 | + |
| 249 | + |
| 250 | +def drop_dataset_index(dataset_id=None, document_id=None): |
| 251 | + if dataset_id is None and document_id is None: |
| 252 | + raise AppApiException(500, _('Dataset ID or Document ID must be provided')) |
| 253 | + |
| 254 | + if dataset_id is not None: |
| 255 | + k_id = dataset_id |
| 256 | + else: |
| 257 | + document = QuerySet(Document).filter(id=document_id).first() |
| 258 | + k_id = document.dataset_id |
| 259 | + |
| 260 | + sql = f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = 'embedding' AND indexname = 'embedding_hnsw_idx_{k_id}'" |
| 261 | + index = sql_execute(sql, []) |
| 262 | + if index: |
| 263 | + sql = f'DROP INDEX "embedding_hnsw_idx_{k_id}"' |
| 264 | + update_execute(sql, []) |
| 265 | + |
| 266 | + |
227 | 267 | class GenerateRelatedSerializer(ApiMixin, serializers.Serializer): |
228 | 268 | model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Model id'))) |
229 | 269 | prompt = serializers.CharField(required=True, error_messages=ErrMessage.uuid(_('Prompt word'))) |
|
0 commit comments