Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
.DS_Store
configs/*
!configs/example.toml
llm_backend/protos/

# Byte-compiled / optimized / DLL files
Expand Down
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ uv run gen-protos

## Usage

Please configure the `configs/config.toml` file (refer to `configs/example.toml` for the options).
Please configure the `configs/config.toml` file.
The following environment variables are required (`export` them or place them in a `.env` file):

- `OPENAI_API_KEY`: Your OpenAI API key.
Expand All @@ -20,7 +20,13 @@ The following environment variables are required (`export` them or place them in
- `QDRANT_COLLECTION`: The Qdrant collection name.

```shell
uv run serve --config configs/config.toml
python3 scripts/serve.py --config configs/config.toml
```

You can refer to `scripts/client.py` for an example implementation of a client:

```shell
python3 scripts/client.py
```

## Features
Expand Down
44 changes: 44 additions & 0 deletions configs/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
[server]
host = 'localhost'
port = 50051
max_workers = 10

[service.retrieve]
# Name of the embedding model. Available models can be found [here](https://huggingface.co/models?library=sentence-transformers&language=zh).
embedding_model = 'intfloat/multilingual-e5-large'

# The template must contain the `{keywords}` placeholder.
prompt_template = 'Please search for the content related to the following keywords: {keywords}.'
similarity_top_k = 5

[service.summarize]
system_template = """
# Project Mission: My project mission is to extract 5 articles of the same type from the internet each time and provide them to Chat GPT in the same format to generate summaries and digests, making it convenient for the general public to read.
# Input Format: The format during input is as follows: 1.xxx 2.xxx 3.xxx 4.xxx 5.xxx Each news article is numbered with a digit title. There is a blank line between different news articles, but within the same article, there are no line breaks.
# Detailed Project Execution: The detailed execution of the project involves refraining from adding personal opinions. I only generate summaries based on the provided news and refrain from providing responses beyond the scope of the news.
# Audience for My Content: The audience comprises professionals from various fields, as well as students and homemakers. They span a wide range of age groups and have a strong desire for knowledge. However, due to limited time, they cannot personally read a large amount of news information. Therefore, my content typically needs to be transformed into something understandable by the general public, with logical analysis involving specific questions and solutions.

# Assuming you are now a reader, think step by step about what you think the key points of the news would be, and provide the first version of the summary. Then, based on this summary, pose sub-questions and further modify to provide the final summary.
# Answer in Traditional Chinese, and refrain from providing thoughts and content beyond what you've provided. Endeavor to comprehensively describe the key points of the news.
# Responses should strive to be rigorous and formal, with real evidence when answering questions.
# Answers can be as complete and comprehensive as possible, expanding on details and actual content.
# The "Output Format" is: provide an overarching title that summarizes the news content above, then summarizes the content.
"""

# The template must contain the `{context_str}` and `{query_str}` placeholders.
user_template = """
{query_str}
---------------------
{context_str}"""

# The content of `{query_str}` placeholder in the user template.
query_str = '假設你是一個摘要抓取者,請將以下---內的文字做一篇文章摘要,用文章敘述的方式呈現,不要用列點的,至少要有500字,要有標題。'

# How the request strings are transformed into the content passed to the LLM.
# Must be one of:
# - 'plain': The query string is the same as the request string.
# - 'numbered': Add a number (1., 2., ...) to the beginning of each request string.
content_format = 'plain'

[service.summarize.llm]
model = 'gpt-4o-mini'
43 changes: 0 additions & 43 deletions configs/example.toml

This file was deleted.

31 changes: 6 additions & 25 deletions llm_backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
import os

from grpc._server import _Server
import grpc
from pydantic import BaseModel

from .search import (
SearchService,
add_SearchServiceServicer_to_server,
)
from .search.config import SearchConfig
from .summarize import (
SummarizeService,
add_SummarizeServiceServicer_to_server,
)
from .summarize.config import SummarizeConfig
from .rag import RagConfig, RagService, add_RagServiceServicer_to_server


class ServerConfig(BaseModel):
Expand All @@ -21,21 +12,11 @@ class ServerConfig(BaseModel):
max_workers: int = (os.cpu_count() or 1) * 5


class ServiceConfig(BaseModel):
search: SearchConfig
summarize: SummarizeConfig


class Config(BaseModel):
server: ServerConfig
service: ServiceConfig


def setup_search_service(config: Config, server: _Server):
search_service = SearchService(config.service.search)
add_SearchServiceServicer_to_server(search_service, server)
service: RagConfig


def setup_summarize_service(config: Config, server: _Server):
summarize_service = SummarizeService(config.service.summarize)
add_SummarizeServiceServicer_to_server(summarize_service, server)
def setup_rag_service(config: Config, server: grpc.aio.Server):
    """Instantiate the RAG servicer from *config* and register it on *server*."""
    add_RagServiceServicer_to_server(RagService(config.service), server)
5 changes: 5 additions & 0 deletions llm_backend/rag/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ..protos.rag_pb2_grpc import (
add_RagServiceServicer_to_server as add_RagServiceServicer_to_server,
)
from .config import RagConfig as RagConfig
from .service import RagService as RagService
61 changes: 42 additions & 19 deletions llm_backend/summarize/config.py → llm_backend/rag/config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from enum import StrEnum
from typing import Annotated

from llama_index.llms.openai.utils import ALL_AVAILABLE_MODELS
from pydantic import AfterValidator, BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic_settings import BaseSettings

from ..utils import contains_placeholder
from .content_formatters import ContentFormat

DEFAULT_EMBEDDING_MODEL = "intfloat/multilingual-e5-large"
DEFAULT_OPENAI_MODEL = "gpt-4o-mini"
DEFAULT_SIMILARITY_TOP_K = 10
DEFAULT_QUERY_PROMPT_TEMPLATE = (
"Please search for the content related to the following keywords: {keywords}."
)
DEFAULT_SYSTEM_TEMPLATE = (
"You are an expert Q&A system that is trusted around the world.\n"
"Always answer the query using the provided context information,"
Expand All @@ -30,9 +35,33 @@
DEFAULT_QUERY_STR = "請用繁體中文總結這幾篇新聞。"


class ContentFormat(StrEnum):
PLAIN = "plain"
NUMBERED = "numbered"
def contains_placeholder(*placeholders: str):
    """Build a template validator for use with ``pydantic.AfterValidator``.

    The returned callable checks that every ``{placeholder}`` (braces included)
    occurs literally in the template string, raising ``ValueError`` for the
    first one that is missing, and returns the template unchanged otherwise.
    """

    def validate_template(template: str):
        missing = [name for name in placeholders if f"{{{name}}}" not in template]
        if missing:
            raise ValueError(f"Template must contain '{{{missing[0]}}}'")
        return template

    return validate_template


class QDrantConfig(BaseSettings):
    """Connection settings for the Qdrant vector store, read from environment variables."""

    # FIX: the previous default of "test" looked like a debugging leftover and is
    # not a usable host; "localhost" matches a local Qdrant instance. The value
    # is overridden by QDRANT_HOST, which the README documents as required.
    host: str = Field("localhost", validation_alias="QDRANT_HOST")
    # 6333 is Qdrant's default HTTP/REST port.
    port: int = Field(6333, gt=0, validation_alias="QDRANT_PORT")
    collection: str = Field("news", validation_alias="QDRANT_COLLECTION")


class RetrieveConfig(BaseModel):
    """Configuration for the retrieval stage of the RAG pipeline."""

    # Pydantic deep-copies model defaults, so instances do not share this value.
    vector_database: QDrantConfig = QDrantConfig()  # type: ignore
    embedding_model: str = Field(
        DEFAULT_EMBEDDING_MODEL,
        description="Name of embedding model. "
        "All available models can be found [here](https://huggingface.co/models?library=sentence-transformers&language=zh).",
    )
    # Must contain the literal "{keywords}" placeholder (enforced by the validator).
    prompt_template: Annotated[
        str, AfterValidator(contains_placeholder("keywords"))
    ] = DEFAULT_QUERY_PROMPT_TEMPLATE
    # FIX: was gt=1, which rejected the perfectly valid top_k of 1; retrieving
    # at least one result is the correct lower bound.
    similarity_top_k: int = Field(DEFAULT_SIMILARITY_TOP_K, ge=1)


def is_available_model(model_name: str):
Expand All @@ -43,23 +72,17 @@ def is_available_model(model_name: str):
return model_name


class ChatgptConfig(BaseSettings):
model_config = SettingsConfigDict(
env_file=(".env", ".env.prod"),
env_file_encoding="utf-8",
case_sensitive=True,
extra="ignore",
)

class ChatGptConfig(BaseSettings):
    """OpenAI chat-model settings; the API key is read from the environment."""

    # Required: populated from the OPENAI_API_KEY environment variable.
    api_key: str = Field(validation_alias="OPENAI_API_KEY")
    # Chat model name; `is_available_model` runs after validation and rejects
    # names unknown to llama-index's OpenAI integration (ALL_AVAILABLE_MODELS).
    model: Annotated[
        str,
        Field(DEFAULT_OPENAI_MODEL),
        AfterValidator(is_available_model),
    ]


class SummarizeQueryConfig(BaseModel):
class SummarizeConfig(BaseModel):
llm: ChatGptConfig = ChatGptConfig() # type: ignore
system_template: str = DEFAULT_SYSTEM_TEMPLATE
user_template: Annotated[
str, AfterValidator(contains_placeholder("context_str", "query_str"))
Expand All @@ -71,6 +94,6 @@ class SummarizeQueryConfig(BaseModel):
content_format: ContentFormat = ContentFormat.PLAIN


class SummarizeConfig(BaseModel):
chatgpt: ChatgptConfig
query: SummarizeQueryConfig
class RagConfig(BaseModel):
    """Top-level configuration for the RAG service: retrieval plus summarization."""

    retrieve: RetrieveConfig
    summarize: SummarizeConfig
17 changes: 17 additions & 0 deletions llm_backend/rag/content_formatters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from collections.abc import Callable, Sequence
from enum import StrEnum


class ContentFormat(StrEnum):
PLAIN = "plain"
NUMBERED = "numbered"


ContentFormatter = Callable[[Sequence[str]], Sequence[str]]

CONTENT_FORMATTERS: dict[ContentFormat, ContentFormatter] = {
ContentFormat.PLAIN: lambda x: x,
ContentFormat.NUMBERED: lambda x: [
f"{i}. {line}" for i, line in enumerate(x, start=1)
],
}
25 changes: 25 additions & 0 deletions llm_backend/rag/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import grpc

from llm_backend.protos import rag_pb2, rag_pb2_grpc
from llm_backend.rag.config import RagConfig
from llm_backend.rag.workflow import RagWorkflow


class RagService(rag_pb2_grpc.RagServiceServicer):
    """gRPC servicer that answers `Rag` calls by running the RAG workflow."""

    def __init__(self, config: RagConfig):
        # One workflow instance is built up front and reused for every request.
        self.workflow = RagWorkflow(config=config)

    async def Rag(
        self,
        request: rag_pb2.RagRequest,
        context: grpc.aio.ServicerContext,
    ):
        """Retrieve documents matching the request keywords and summarize them."""
        workflow_output = await self.workflow.run(
            keywords=request.keywords,
            similarity_top_k=request.similarity_top_k,
        )
        response = rag_pb2.RagResponse(
            retrieved_ids=workflow_output["retrieved_ids"],
            summary=workflow_output["summary"],
        )
        return response
Loading