4 changes: 2 additions & 2 deletions aperag/context/context_test.py
@@ -19,7 +19,7 @@
import time
from datetime import datetime

-from langchain import PromptTemplate
+from langchain_core.prompts import PromptTemplate
from tabulate import tabulate

from aperag.context.context import ContextManager
@@ -47,7 +47,7 @@
table_format = """
<html>
<head>
-<style>
+<style>
table {{ table-layout: fixed; border: 1px solid black; border-collapse: collapse; }}
th {{ border: 1px solid black; border-collapse: collapse; white-space:break-spaces; overflow:hidden; position:sticky; top: 0; padding: 5px;}}
td {{ border: 1px solid black; border-collapse: collapse; white-space:break-spaces; overflow:hidden; padding: 5px; word-wrap: break-all}}
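This import move from `langchain` to `langchain_core.prompts` repeats across every file in this PR. A minimal sketch of the migration, assuming langchain >= 0.1 with the langchain-core package installed; the template text is illustrative:

```python
# Old import, removed from the langchain top-level namespace in langchain >= 0.1:
#   from langchain import PromptTemplate
# New canonical location, in the langchain-core package:
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate(template="Question: {query}\nAnswer:", input_variables=["query"])
print(prompt.format(query="What does ApeRAG do?"))
```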
4 changes: 2 additions & 2 deletions aperag/pipeline/base_pipeline.py
@@ -13,13 +13,13 @@
# limitations under the License.

import json
-import uuid
import re
+import uuid
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

-from langchain import PromptTemplate
from langchain.schema import AIMessage, HumanMessage
+from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel

from aperag.chat.history.base import BaseChatMessageHistory
7 changes: 3 additions & 4 deletions aperag/pipeline/common_pipeline.py
@@ -15,9 +15,8 @@
import asyncio
import logging
import random
-import re

-from langchain import PromptTemplate
+from langchain_core.prompts import PromptTemplate

from aperag.llm.prompts import COMMON_FILE_TEMPLATE
from aperag.pipeline.base_pipeline import RELATED_QUESTIONS, Message, Pipeline
@@ -65,12 +64,12 @@ async def run(self, message, gen_references=False, message_id="", file=None):
related_questions.update(self.welcome_question)
if len(self.welcome_question) >= 3:
need_related_question = False


# TODO: divide file_content into several parts and call API separately.
context = file if file else ""
context += self.bot_context

if len(context) > self.context_window - 500:
context = context[:self.context_window - 500]

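The truncation guard above matters because `context_window` is an integer budget, not a sequence, so `len(self.context_window)` would raise a TypeError. A small self-contained sketch with illustrative values:

```python
# Illustrative values: context_window is an int budget, so the slice uses it
# directly; calling len() on an int would raise TypeError at runtime.
context_window = 4096
context = "x" * 10_000
if len(context) > context_window - 500:
    context = context[:context_window - 500]
assert len(context) == context_window - 500
```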
2 changes: 1 addition & 1 deletion aperag/pipeline/keyword_extractor.py
@@ -19,7 +19,7 @@
from typing import Any, Dict

from elasticsearch import AsyncElasticsearch
-from langchain import PromptTemplate
+from langchain_core.prompts import PromptTemplate

from aperag.llm.base import Predictor, PredictorType
from aperag.llm.prompts import KEYWORD_PROMPT_TEMPLATE
27 changes: 13 additions & 14 deletions aperag/pipeline/knowledge_pipeline.py
@@ -16,11 +16,9 @@
import json
import logging
import random
-import re

-from langchain import PromptTemplate
+from langchain_core.prompts import PromptTemplate

-from config import settings
from aperag.context.context import ContextManager
from aperag.context.full_text import search_document
from aperag.llm.prompts import (
@@ -39,6 +37,7 @@
generate_vector_db_collection_name,
now_unix_milliseconds,
)
+from config import settings

logger = logging.getLogger(__name__)

@@ -120,11 +119,11 @@ async def run(self, message, gen_references=False, message_id=""):
history = []
tot_history_querys = ''
messages = await self.history.messages
-history_querys = [json.loads(message.content)["query"] for message in messages if message.additional_kwargs["role"] == "human"]
+history_querys = [json.loads(message.content)["query"] for message in messages if message.additional_kwargs["role"] == "human"]

if self.memory:
tot_history_querys = '\n'.join(history_querys[-self.memory_limit_count:])+'\n'

references = []
related_questions = set()
response = ""
@@ -135,7 +134,7 @@
vector = self.embedding_model.embed_query(tot_history_querys+message)
logger.info("[%s] embedding query end", log_prefix)
# hyde_task = asyncio.create_task(self.generate_hyde_message(message))

results = await async_run(self.qa_context_manager.query, tot_history_querys + message, score_threshold=0.5, topk=6, vector=vector)
logger.info("[%s] find relevant qa pairs in vector db end", log_prefix)
for result in results:
@@ -144,17 +143,17 @@
response = result_text["answer"]
if result.score < 0.8:
related_questions.add(result_text["question"])

# if len(related_questions) >= 3:
# need_related_question = False

if response != "":
yield response

if self.use_related_question and need_related_question:
related_question_prompt = self.related_question_prompt.format(query=message, context=response)
related_question_task = asyncio.create_task(self.generate_related_question(related_question_prompt))

else:
results = await async_run(self.context_manager.query, tot_history_querys + message,
score_threshold=self.score_threshold, topk=self.topk * 6, vector=vector)
@@ -165,14 +164,14 @@
# score_threshold=self.score_threshold, topk=self.topk * 6, vector=new_vector)
# results_set = set([result.text for result in results])
# results.extend(result for result in results2 if result.text not in results_set)

if self.bot_context != "":
bot_context_result = DocumentWithScore(
text=self.bot_context, # type: ignore
score=0,
)
results.append(bot_context_result)

if len(results) > 1:
results = await rerank(message, results)
logger.info("[%s] rerank candidates end", log_prefix)
@@ -200,7 +199,7 @@ async def run(self, message, gen_references=False, message_id=""):
related_questions.update(self.welcome_question)
if len(related_questions) >= 3:
need_related_question = False

if self.use_related_question and need_related_question:
related_question_prompt = self.related_question_prompt.format(query=message, context=context)
related_question_task = asyncio.create_task(self.generate_related_question(related_question_prompt))
14 changes: 7 additions & 7 deletions aperag/readers/base_embedding.py
@@ -9,24 +9,24 @@
import aiohttp
import requests
from langchain.embeddings.base import Embeddings
-from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings
+from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoModelForSequenceClassification, AutoTokenizer, MT5EncoderModel

-from aperag.query.query import DocumentWithScore
-from aperag.vectorstore.connector import VectorStoreConnectorAdaptor
from config.settings import (
EMBEDDING_BACKEND,
EMBEDDING_DEVICE,
+EMBEDDING_DIMENSIONS,
EMBEDDING_MODEL,
EMBEDDING_SERVICE_MODEL,
-EMBEDDING_SERVICE_TOKEN,
EMBEDDING_SERVICE_MODEL_UID,
+EMBEDDING_SERVICE_TOKEN,
EMBEDDING_SERVICE_URL,
RERANK_BACKEND,
RERANK_SERVICE_MODEL_UID,
RERANK_SERVICE_URL,
-EMBEDDING_DIMENSIONS,
)
+from aperag.query.query import DocumentWithScore
+from aperag.vectorstore.connector import VectorStoreConnectorAdaptor


class EmbeddingService(Embeddings):
@@ -309,8 +309,8 @@ async def rank(self, query, results: List[DocumentWithScore]):
model_kwargs={'device': EMBEDDING_DEVICE},
),
"text2vec": lambda: Text2VecEmbedding(device=EMBEDDING_DEVICE),
"bge": lambda: HuggingFaceBgeEmbeddings(
model_name="BAAI/bge-large-zh",
"bge": lambda: HuggingFaceEmbeddings(
model_name="BAAI/bge-large-zh-v1.5",
model_kwargs={'device': EMBEDDING_DEVICE},
encode_kwargs={'normalize_embeddings': True, 'batch_size': 16}
)
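A hedged sketch of the replacement embedding setup, assuming the `langchain-huggingface` package is installed; the `"cpu"` device stands in for the repo's `EMBEDDING_DEVICE` setting:

```python
from langchain_huggingface import HuggingFaceEmbeddings

# HuggingFaceEmbeddings covers the BGE models that the deprecated
# HuggingFaceBgeEmbeddings class handled; BGE retrieval quality relies on
# normalized vectors, hence normalize_embeddings=True.
embeddings = HuggingFaceEmbeddings(
    model_name="BAAI/bge-large-zh-v1.5",
    model_kwargs={"device": "cpu"},  # stands in for EMBEDDING_DEVICE
    encode_kwargs={"normalize_embeddings": True, "batch_size": 16},
)
vector = embeddings.embed_query("向量检索测试")
print(len(vector))  # bge-large-zh-v1.5 produces 1024-dimensional vectors
```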
2 changes: 1 addition & 1 deletion aperag/readers/qa_generator.py
@@ -2,7 +2,7 @@
import re
from abc import ABC, abstractmethod

-from langchain import PromptTemplate
+from langchain_core.prompts import PromptTemplate

from aperag.llm.base import Predictor, PredictorType
from aperag.llm.prompts import CHINESE_QA_EXTRACTION_PROMPT_TEMPLATE
4 changes: 2 additions & 2 deletions aperag/readers/question_generator.py
@@ -2,7 +2,7 @@
import re
from abc import ABC

-from langchain import PromptTemplate
+from langchain_core.prompts import PromptTemplate

from aperag.llm.base import Predictor, PredictorType
from aperag.llm.prompts import QUESTION_EXTRACTION_PROMPT_TEMPLATE_V2
@@ -15,7 +15,7 @@ class QuestionGenerator(ABC):
def __init__(self, **kwargs):
self.prompt_template = PromptTemplate(template=QUESTION_EXTRACTION_PROMPT_TEMPLATE_V2, input_variables=["context"])
self.predictor = Predictor.from_model(model_name="gpt-4-1106-preview", predictor_type=PredictorType.CUSTOM_LLM, **kwargs)

def gen_questions(self, text):
prompt = self.prompt_template.format(context=text)
response = ""
36 changes: 18 additions & 18 deletions aperag/readers/sensitive_filter.py
@@ -4,7 +4,7 @@
from abc import ABC
from typing import Dict, Tuple

-from langchain import PromptTemplate
+from langchain_core.prompts import PromptTemplate

from config import settings
from aperag.db.models import ProtectAction
@@ -16,26 +16,26 @@


class SensitiveFilter(ABC):

def __init__(self, **kwargs):
self.prompt_template = PromptTemplate(template=SENSITIVE_INFORMATION_TEMPLATE, input_variables=["context", "types"])
sensitive_filter_model = settings.SENSITIVE_FILTER_MODEL
self.sensitive_protect_llm = False
if sensitive_filter_model != '':
self.sensitive_protect_llm = True
-self.predictor = Predictor.from_model(self.sensitive_filter_model, PredictorType.CUSTOM_LLM, **kwargs)
+self.predictor = Predictor.from_model(self.sensitive_filter_model, PredictorType.CUSTOM_LLM, **kwargs)

def sensitive_filter_llm(self, context, types=["密码", "API-KEY", "special token"]):
prompt = self.prompt_template.format(context=context, types=types)
response = ""
for tokens in self.predictor.generate_stream([], prompt):
response += tokens

try:
start = response.find('[')
end = response.rfind(']')
raw_results = json.loads(response[start:end + 1])

# check raw results
results = []
for result in raw_results:
@@ -52,14 +52,14 @@ def sensitive_filter_llm(self, context, types=["密码", "API-KEY", "special token"]):
results.append({"text": text, "span": span, "type": text_type})
except Exception:
return context, []

return context, results

def sensitive_filter(self, text: str, sensitive_protect_method: str) -> Tuple[str, Dict]:
output_sensitive_info = {}
-output_text = text
+output_text = text
try:
-result = subprocess.run(['dlptool', text], capture_output=True, text=True)
+result = subprocess.run(['dlptool', text], capture_output=True, text=True)
output = result.stdout.split('\n')
dlp_num = int(output[0])
dlp_outputs = []
@@ -70,7 +70,7 @@ def sensitive_filter(self, text: str, sensitive_protect_method: str) -> Tuple[str, Dict]:
output_sensitive_info = {"chunk": text, "masked_chunk": dlp_masktext, "sensitive_info": dlp_outputs}
if sensitive_protect_method == ProtectAction.REPLACE_WORDS:
output_text = dlp_masktext

# llm check
if self.sensitive_protect_llm:
llm_masktext, llm_outputs = self.sensitive_filter_llm(text)
Expand All @@ -80,19 +80,19 @@ def sensitive_filter(self, text: str, sensitive_protect_method: str) -> Tuple[st
output_text = llm_masktext
except Exception as e:
logger.error(f"sensitive filter failed:{e}")

return output_text, output_sensitive_info


class SensitiveFilterClassify(ABC):

def __init__(self, **kwargs):
self.prompt_template = PromptTemplate(template=CLASSIFY_SENSITIVE_INFORMATION_TEMPLATE, input_variables=["context", "types"])
sensitive_filter_model = settings.SENSITIVE_FILTER_MODEL
self.sensitive_protect_llm = False
if sensitive_filter_model != '':
self.sensitive_protect_llm = True
-self.predictor = Predictor.from_model(self.sensitive_filter_model, PredictorType.CUSTOM_LLM, **kwargs)
+self.predictor = Predictor.from_model(self.sensitive_filter_model, PredictorType.CUSTOM_LLM, **kwargs)
def sensitive_filter_llm(self, context, types=["密码", "API-KEY", "special token"]):
prompt = self.prompt_template.format(context=context, types=types)
response = ""
@@ -103,20 +103,20 @@ def sensitive_filter_llm(self, context, types=["密码", "API-KEY", "special token"]):
is_sensitive = True

return is_sensitive

def sensitive_filter(self, text: str, sensitive_protect_method: str) -> Tuple[str, Dict]:
output_sensitive_info = {}
-output_text = text
+output_text = text
try:
-result = subprocess.run(['dlptool', text], capture_output=True, text=True)
+result = subprocess.run(['dlptool', text], capture_output=True, text=True)
output = result.stdout.split('\n')
dlp_num = int(output[0])
dlp_outputs = []
for line in output[1:dlp_num + 1]:
dlp_outputs.append(json.loads(line))
dlp_masktext = '\n'.join(output[dlp_num + 2:])
is_sensitive = True

if dlp_num > 0:
# llm check
if self.sensitive_protect_llm:
@@ -127,5 +127,5 @@ def sensitive_filter(self, text: str, sensitive_protect_method: str) -> Tuple[str, Dict]:
output_text = dlp_masktext
except Exception as e:
logger.error(f"sensitive filter failed:{e}")

return output_text, output_sensitive_info
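Both filter classes recover a JSON array from free-form LLM output by slicing between the outermost brackets. A self-contained sketch of that strategy; the helper name is illustrative, not part of the repo:

```python
import json

# Same strategy as sensitive_filter_llm above: take the outermost [...] span
# of the model response and parse it, returning [] on any parsing failure.
def extract_json_array(response: str) -> list:
    start = response.find("[")
    end = response.rfind("]")
    if start == -1 or end <= start:
        return []
    try:
        return json.loads(response[start:end + 1])
    except json.JSONDecodeError:
        return []

print(extract_json_array('noise [{"text": "hunter2", "type": "密码"}] noise'))
```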
2 changes: 1 addition & 1 deletion aperag/readers/test_image_embedding.py
@@ -3,7 +3,7 @@
import time

import requests
-from langchain import PromptTemplate
+from langchain_core.prompts import PromptTemplate

from aperag.query.query import QueryWithEmbedding, get_packed_answer
from aperag.readers.base_embedding import get_embedding_model
2 changes: 1 addition & 1 deletion aperag/readers/test_local_path_embedding.py
@@ -3,7 +3,7 @@
import time

import requests
-from langchain import PromptTemplate
+from langchain_core.prompts import PromptTemplate

from aperag.query.query import QueryWithEmbedding, get_packed_answer
from aperag.readers.base_embedding import get_embedding_model
2 changes: 1 addition & 1 deletion aperag/views/utils.py
@@ -19,7 +19,7 @@
from typing import Dict

from django.http import HttpRequest, HttpResponse
-from langchain import PromptTemplate
+from langchain_core.prompts import PromptTemplate
from ninja.main import Exc
from pydantic import ValidationError
