|
| 1 | +# coding=utf-8 |
| 2 | +""" |
| 3 | + @project: MaxKB |
| 4 | + @Author:虎 |
| 5 | + @file: model_apply_serializers.py |
| 6 | + @date:2024/8/20 20:39 |
| 7 | + @desc: |
| 8 | +""" |
| 9 | +from django.db import connection |
| 10 | +from django.db.models import QuerySet |
| 11 | +from langchain_core.documents import Document |
| 12 | +from rest_framework import serializers |
| 13 | + |
| 14 | +from common.config.embedding_config import ModelManage |
| 15 | +from django.utils.translation import gettext_lazy as _ |
| 16 | + |
| 17 | +from models_provider.models import Model |
| 18 | +from models_provider.tools import get_model |
| 19 | + |
| 20 | + |
def get_embedding_model(model_id):
    """Return the model instance registered under ``model_id``.

    The ``Model`` row is looked up first, the database connection is then
    closed explicitly, and finally ``ModelManage.get_model`` is asked for the
    instance, with a loader that builds it from the fetched row via
    ``get_model(..., use_local=True)``.

    :param model_id: primary key of the ``Model`` row to load
    :return: whatever ``ModelManage.get_model`` returns for this id
    """
    model_record = QuerySet(Model).filter(id=model_id).first()
    # Manually close the database connection.
    connection.close()

    def _build_model(_id):
        # The loader works from the row fetched above, not from its argument.
        return get_model(model_record, use_local=True)

    return ModelManage.get_model(model_id, _build_model)
| 28 | + |
| 29 | + |
class EmbedDocuments(serializers.Serializer):
    """Request payload for embedding a batch of texts."""

    # NOTE: there must be no trailing comma after this assignment. With a
    # trailing comma ``texts`` becomes a one-element tuple instead of a Field,
    # so the serializer declares no ``texts`` field and never validates it.
    texts = serializers.ListField(
        required=True,
        child=serializers.CharField(required=True, label=_('vector text')),
        label=_('vector text list'),
    )
| 34 | + |
| 35 | + |
class EmbedQuery(serializers.Serializer):
    """Request payload for embedding a single query string."""

    # The text to be embedded; mandatory.
    text = serializers.CharField(
        label=_('vector text'),
        required=True,
    )
| 38 | + |
| 39 | + |
class CompressDocument(serializers.Serializer):
    """A single document submitted for compression."""

    # Document body; mandatory.
    page_content = serializers.CharField(
        label=_('text'),
        required=True,
    )
    # Free-form metadata attached to the document; may be omitted.
    metadata = serializers.DictField(
        label=_('metadata'),
        required=False,
    )
| 43 | + |
| 44 | + |
class CompressDocuments(serializers.Serializer):
    """Request payload for compressing a list of documents against a query."""

    # The documents to compress, each validated as a CompressDocument.
    documents = CompressDocument(many=True, required=True)
    # The query the documents are compressed against; mandatory.
    query = serializers.CharField(
        label=_('query'),
        required=True,
    )
| 48 | + |
| 49 | + |
class ModelApplySerializers(serializers.Serializer):
    """Apply the model identified by ``model_id`` to request data.

    Each method optionally validates this serializer and the request payload,
    resolves the model through ``get_embedding_model`` and delegates to the
    matching method on the model.
    """
    model_id = serializers.UUIDField(required=True, label=_('model id'))

    def embed_documents(self, instance, with_valid=True):
        """Embed every text in ``instance['texts']``.

        :param instance: request payload; either a multi-value mapping that
            offers ``getlist`` or a plain mapping holding a list under ``texts``
        :param with_valid: validate this serializer and the payload first
        :return: the result of the model's ``embed_documents``
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            EmbedDocuments(data=instance).is_valid(raise_exception=True)

        model = get_embedding_model(self.data.get('model_id'))
        # Multi-value mappings (e.g. Django's QueryDict) need getlist() to
        # return all values; plain dict payloads only offer get().
        if hasattr(instance, 'getlist'):
            texts = instance.getlist('texts')
        else:
            texts = instance.get('texts')
        return model.embed_documents(texts)

    def embed_query(self, instance, with_valid=True):
        """Embed the single query string in ``instance['text']``.

        :param instance: request payload containing ``text``
        :param with_valid: validate this serializer and the payload first
        :return: the result of the model's ``embed_query``
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            EmbedQuery(data=instance).is_valid(raise_exception=True)

        model = get_embedding_model(self.data.get('model_id'))
        return model.embed_query(instance.get('text'))

    def compress_documents(self, instance, with_valid=True):
        """Compress ``instance['documents']`` against ``instance['query']``.

        :param instance: request payload containing ``documents`` (mappings
            with ``page_content`` and optional ``metadata``) and ``query``
        :param with_valid: validate this serializer and the payload first
        :return: list of ``{'page_content': ..., 'metadata': ...}`` dicts built
            from the documents the model returns
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            CompressDocuments(data=instance).is_valid(raise_exception=True)
        model = get_embedding_model(self.data.get('model_id'))
        documents = [
            Document(
                page_content=document.get('page_content'),
                # ``metadata`` is optional in the payload; fall back to an
                # empty dict so Document is never constructed with None.
                metadata=document.get('metadata') or {},
            )
            for document in instance.get('documents')
        ]
        compressed = model.compress_documents(documents, instance.get('query'))
        return [{'page_content': d.page_content, 'metadata': d.metadata} for d in compressed]
0 commit comments