From 4fd424c43811f515bf6a07b2d580a672bae9bc09 Mon Sep 17 00:00:00 2001 From: David Chiu Date: Mon, 17 Feb 2025 16:49:42 +0800 Subject: [PATCH 1/4] feat: update proto files --- protos | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/protos b/protos index a6ed1e5..6244aa1 160000 --- a/protos +++ b/protos @@ -1 +1 @@ -Subproject commit a6ed1e50003357b55e503ed639e77a15f296da6e +Subproject commit 6244aa1a7c26262f3aa8ecd1a598f1aedb48fe92 From ba9ed841d4d0b7e9a3dc97bd6a6407903de06e12 Mon Sep 17 00:00:00 2001 From: David Chiu Date: Mon, 10 Mar 2025 04:09:54 +0800 Subject: [PATCH 2/4] refactor(rag): combine services to a rag workflow --- llm_backend/__init__.py | 31 ++----- llm_backend/rag/__init__.py | 5 ++ llm_backend/{summarize => rag}/config.py | 61 +++++++++---- llm_backend/rag/content_formatters.py | 17 ++++ llm_backend/rag/service.py | 25 ++++++ llm_backend/rag/workflow.py | 105 +++++++++++++++++++++++ llm_backend/search/__init__.py | 4 - llm_backend/search/config.py | 51 ----------- llm_backend/search/service.py | 59 ------------- llm_backend/summarize/__init__.py | 4 - llm_backend/summarize/service.py | 65 -------------- llm_backend/utils.py | 8 -- pyproject.toml | 2 +- scripts/serve.py | 80 ++++++++--------- uv.lock | 3 + 15 files changed, 241 insertions(+), 279 deletions(-) create mode 100644 llm_backend/rag/__init__.py rename llm_backend/{summarize => rag}/config.py (52%) create mode 100644 llm_backend/rag/content_formatters.py create mode 100644 llm_backend/rag/service.py create mode 100644 llm_backend/rag/workflow.py delete mode 100644 llm_backend/search/__init__.py delete mode 100644 llm_backend/search/config.py delete mode 100644 llm_backend/search/service.py delete mode 100644 llm_backend/summarize/__init__.py delete mode 100644 llm_backend/summarize/service.py delete mode 100644 llm_backend/utils.py diff --git a/llm_backend/__init__.py b/llm_backend/__init__.py index f43ff5c..64a753b 100644 --- a/llm_backend/__init__.py +++ b/llm_backend/__init__.py @@ -1,18 +1,9 @@ import os -from grpc._server import _Server +import grpc from pydantic import BaseModel -from .search import ( - SearchService, - add_SearchServiceServicer_to_server, -) -from .search.config import SearchConfig -from .summarize import ( - SummarizeService, - add_SummarizeServiceServicer_to_server, -) -from .summarize.config import SummarizeConfig +from .rag import RagConfig, RagService, add_RagServiceServicer_to_server class ServerConfig(BaseModel): @@ -21,21 +12,11 @@ class ServerConfig(BaseModel): max_workers: int = (os.cpu_count() or 1) * 5 -class ServiceConfig(BaseModel): - search: SearchConfig - summarize: SummarizeConfig - - class Config(BaseModel): server: ServerConfig - service: ServiceConfig - - -def setup_search_service(config: Config, server: _Server): - search_service = SearchService(config.service.search) - add_SearchServiceServicer_to_server(search_service, server) + service: RagConfig -def setup_summarize_service(config: Config, server: _Server): - summarize_service = SummarizeService(config.service.summarize) - add_SummarizeServiceServicer_to_server(summarize_service, server) +def setup_rag_service(config: Config, server: grpc.aio.Server): + rag_service = RagService(config.service) + add_RagServiceServicer_to_server(rag_service, server) diff --git a/llm_backend/rag/__init__.py b/llm_backend/rag/__init__.py new file mode 100644 index 0000000..752d4ff --- /dev/null +++ b/llm_backend/rag/__init__.py @@ -0,0 +1,5 @@ +from ..protos.rag_pb2_grpc import ( + add_RagServiceServicer_to_server as add_RagServiceServicer_to_server, +) +from .config import RagConfig as RagConfig +from .service import RagService as RagService diff --git a/llm_backend/summarize/config.py b/llm_backend/rag/config.py similarity index 52% rename from llm_backend/summarize/config.py rename to llm_backend/rag/config.py index bd9b6ad..668b319 100644 --- a/llm_backend/summarize/config.py +++ b/llm_backend/rag/config.py @@ -1,12 +1,17 @@ -from enum import StrEnum from typing import Annotated from llama_index.llms.openai.utils import ALL_AVAILABLE_MODELS from pydantic import AfterValidator, BaseModel, Field -from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic_settings import BaseSettings -from ..utils import contains_placeholder +from .content_formatters import ContentFormat +DEFAULT_EMBEDDING_MODEL = "intfloat/multilingual-e5-large" +DEFAULT_OPENAI_MODEL = "gpt-4o-mini" +DEFAULT_SIMILARITY_TOP_K = 10 +DEFAULT_QUERY_PROMPT_TEMPLATE = ( + "Please search for the content related to the following keywords: {keywords}." +) DEFAULT_SYSTEM_TEMPLATE = ( "You are an expert Q&A system that is trusted around the world.\n" "Always answer the query using the provided context information," @@ -30,9 +35,33 @@ DEFAULT_QUERY_STR = "請用繁體中文總結這幾篇新聞。" -class ContentFormat(StrEnum): - PLAIN = "plain" - NUMBERED = "numbered" +def contains_placeholder(*placeholders: str): + def validate_template(template: str): + for placeholder in placeholders: + if f"{{{placeholder}}}" not in template: + raise ValueError(f"Template must contain '{{{placeholder}}}'") + return template + + return validate_template + + +class QDrantConfig(BaseSettings): + host: str = Field("test", validation_alias="QDRANT_HOST") + port: int = Field(6333, gt=0, validation_alias="QDRANT_PORT") + collection: str = Field("news", validation_alias="QDRANT_COLLECTION") + + +class RetrieveConfig(BaseModel): + vector_database: QDrantConfig = QDrantConfig() # type: ignore + embedding_model: str = Field( + DEFAULT_EMBEDDING_MODEL, + description="Name of embedding model." + "All available models can be found [here](https://huggingface.co/models?library=sentence-transformers&language=zh).", + ) + prompt_template: Annotated[ + str, AfterValidator(contains_placeholder("keywords")) + ] = DEFAULT_QUERY_PROMPT_TEMPLATE + similarity_top_k: int = Field(DEFAULT_SIMILARITY_TOP_K, gt=1) def is_available_model(model_name: str): @@ -43,23 +72,17 @@ def is_available_model(model_name: str): return model_name -class ChatgptConfig(BaseSettings): - model_config = SettingsConfigDict( - env_file=(".env", ".env.prod"), - env_file_encoding="utf-8", - case_sensitive=True, - extra="ignore", - ) - +class ChatGptConfig(BaseSettings): api_key: str = Field(validation_alias="OPENAI_API_KEY") model: Annotated[ str, - Field("gpt-3.5-turbo"), + Field(DEFAULT_OPENAI_MODEL), AfterValidator(is_available_model), ] -class SummarizeQueryConfig(BaseModel): +class SummarizeConfig(BaseModel): + llm: ChatGptConfig = ChatGptConfig() # type: ignore system_template: str = DEFAULT_SYSTEM_TEMPLATE user_template: Annotated[ str, AfterValidator(contains_placeholder("context_str", "query_str")) @@ -71,6 +94,6 @@ class SummarizeQueryConfig(BaseModel): content_format: ContentFormat = ContentFormat.PLAIN -class SummarizeConfig(BaseModel): - chatgpt: ChatgptConfig - query: SummarizeQueryConfig +class RagConfig(BaseModel): + retrieve: RetrieveConfig + summarize: SummarizeConfig diff --git a/llm_backend/rag/content_formatters.py b/llm_backend/rag/content_formatters.py new file mode 100644 index 0000000..72efaab --- /dev/null +++ b/llm_backend/rag/content_formatters.py @@ -0,0 +1,17 @@ +from collections.abc import Callable, Sequence +from enum import StrEnum + + +class ContentFormat(StrEnum): + PLAIN = "plain" + NUMBERED = "numbered" + + +ContentFormatter = Callable[[Sequence[str]], Sequence[str]] + +CONTENT_FORMATTERS: dict[ContentFormat, ContentFormatter] = { + ContentFormat.PLAIN: lambda x: x, + ContentFormat.NUMBERED: lambda x: [ + f"{i}. {line}" for i, line in enumerate(x, start=1) + ], +} diff --git a/llm_backend/rag/service.py b/llm_backend/rag/service.py new file mode 100644 index 0000000..07281d2 --- /dev/null +++ b/llm_backend/rag/service.py @@ -0,0 +1,25 @@ +import grpc + +from llm_backend.protos import rag_pb2, rag_pb2_grpc +from llm_backend.rag.config import RagConfig +from llm_backend.rag.workflow import RagWorkflow + + +class RagService(rag_pb2_grpc.RagServiceServicer): + def __init__(self, config: RagConfig): + self.workflow = RagWorkflow(config=config) + + async def Rag( + self, + request: rag_pb2.RagRequest, + context: grpc.aio.ServicerContext, + ): + result = await self.workflow.run( + keywords=request.keywords, + similarity_top_k=request.similarity_top_k, + ) + + return rag_pb2.RagResponse( + retrieved_ids=result["retrieved_ids"], + summary=result["summary"], + ) diff --git a/llm_backend/rag/workflow.py b/llm_backend/rag/workflow.py new file mode 100644 index 0000000..9c5029d --- /dev/null +++ b/llm_backend/rag/workflow.py @@ -0,0 +1,105 @@ +import qdrant_client +from llama_index.core import ChatPromptTemplate, VectorStoreIndex +from llama_index.core.bridge.pydantic import BaseModel, Field +from llama_index.core.llms import ChatMessage, MessageRole +from llama_index.core.prompts import PromptType +from llama_index.core.response_synthesizers import ( + ResponseMode, + get_response_synthesizer, +) +from llama_index.core.schema import NodeWithScore +from llama_index.core.workflow import Event, StartEvent, StopEvent, Workflow, step +from llama_index.embeddings.huggingface import HuggingFaceEmbedding +from llama_index.llms.openai import OpenAI +from llama_index.vector_stores.qdrant import QdrantVectorStore + +from .config import RagConfig, RetrieveConfig, SummarizeConfig +from .content_formatters import CONTENT_FORMATTERS + + +class RetrieveResult(BaseModel): + news_id: str = Field(description="The mongodb id of the retrieved news.") + score: float = Field(description="The score of the retrieved news.") + content: str = Field(description="The content of the retrieved news.") + + +class RetrieveEvent(Event): + results: list[RetrieveResult] + + +class RagWorkflow(Workflow): + def __init__(self, config: RagConfig, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__init_retrieve_service(config.retrieve) + self.__init_summarize_service(config.summarize) + + def __init_retrieve_service(self, config: RetrieveConfig): + if (host := config.vector_database.host) == "test": + client = qdrant_client.AsyncQdrantClient(location=":memory:") + else: + client = qdrant_client.AsyncQdrantClient( + host=host, + port=config.vector_database.port, + ) + + self.index = VectorStoreIndex.from_vector_store( + vector_store=QdrantVectorStore( + aclient=client, + collection_name=config.vector_database.collection, + ), + embed_model=HuggingFaceEmbedding( + model_name=config.embedding_model, + ), + ) + + self.query_template = config.prompt_template + self.similarity_top_k = config.similarity_top_k + + def __init_summarize_service(self, config: SummarizeConfig): + self.summarizer = get_response_synthesizer( + llm=OpenAI(model=config.llm.model, api_key=config.llm.api_key), + response_mode=ResponseMode.SIMPLE_SUMMARIZE, + use_async=True, + ) + summarizer_prompt = ChatPromptTemplate( + message_templates=[ + ChatMessage(role=MessageRole.SYSTEM, content=config.system_template), + ChatMessage(role=MessageRole.USER, content=config.user_template), + ], + prompt_type=PromptType.SUMMARY, + ) + self.summarizer.update_prompts({"text_qa_template": summarizer_prompt}) + + self.query_str = config.query_str + self.content_formatter = CONTENT_FORMATTERS[config.content_format] + + @step + async def retrieve(self, ev: StartEvent) -> RetrieveEvent: + prompt = self.query_template.format(keywords=", ".join(ev.keywords)) + similarity_top_k = ev.similarity_top_k or self.similarity_top_k + + retriever = self.index.as_retriever(similarity_top_k=similarity_top_k) + retrieved_results: list[NodeWithScore] = await retriever.aretrieve(prompt) + + return RetrieveEvent( + results=[ + RetrieveResult( + news_id=result.metadata["news_id"], + score=result.get_score(), + content="".join(result.text.split()), + ) + for result in retrieved_results + ] + ) + + @step + async def summarize(self, ev: RetrieveEvent) -> StopEvent: + contents = [result.content for result in ev.results] + texts = self.content_formatter(contents) + summary = str(await self.summarizer.aget_response(self.query_str, texts)) + return StopEvent( + result={ + "retrieved_ids": [result.news_id for result in ev.results], + "summary": summary, + } + ) diff --git a/llm_backend/search/__init__.py b/llm_backend/search/__init__.py deleted file mode 100644 index 2ff8af5..0000000 --- a/llm_backend/search/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from ..protos.search_pb2_grpc import ( - add_SearchServiceServicer_to_server as add_SearchServiceServicer_to_server, -) -from .service import SearchService as SearchService diff --git a/llm_backend/search/config.py b/llm_backend/search/config.py deleted file mode 100644 index 3289d0d..0000000 --- a/llm_backend/search/config.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Annotated - -from pydantic import AfterValidator, BaseModel, Field -from pydantic_settings import BaseSettings, SettingsConfigDict - -from ..utils import contains_placeholder - -DEFAULT_EMBEDDING_MODEL = "moka-ai/m3e-base" -DEFAULT_SIMILARITY_TOP_K = 3 -DEFAULT_QUERY_PROMPT_TEMPLATE = ( - "Please search for the content related to the following keywords: {keywords}." -) - - -class QDrantConfig(BaseSettings): - model_config = SettingsConfigDict( - env_file=(".env", ".env.prod"), - env_file_encoding="utf-8", - env_prefix="QDRANT_", - case_sensitive=True, - extra="ignore", - ) - - host: str = Field("test", validation_alias="HOST") - port: int = Field( - 6333, # 6333 is the default port of Qdrant - validation_alias="PORT", - gt=0, - ) - collection: str = Field("news", validation_alias="COLLECTION") - - -class EmbeddingsConfig(BaseModel): - model: str = Field( - DEFAULT_EMBEDDING_MODEL, - description="Name of embedding model. " - "All available models can be found [here](https://huggingface.co/models?library=sentence-transformers&language=zh).", - ) - - -class QueryConfig(BaseModel): - prompt_template: Annotated[ - str, AfterValidator(contains_placeholder("keywords")) - ] = DEFAULT_QUERY_PROMPT_TEMPLATE - similarity_top_k: int = Field(DEFAULT_SIMILARITY_TOP_K, gt=0) - - -class SearchConfig(BaseModel): - qdrant: QDrantConfig = Field(default_factory=QDrantConfig) # type: ignore - embeddings: EmbeddingsConfig - query: QueryConfig diff --git a/llm_backend/search/service.py b/llm_backend/search/service.py deleted file mode 100644 index d907974..0000000 --- a/llm_backend/search/service.py +++ /dev/null @@ -1,59 +0,0 @@ -import grpc -import qdrant_client -from llama_index.core import VectorStoreIndex -from llama_index.core.schema import NodeWithScore -from llama_index.embeddings.huggingface import HuggingFaceEmbedding -from llama_index.vector_stores.qdrant import QdrantVectorStore - -from llm_backend.protos import search_pb2, search_pb2_grpc - -from .config import SearchConfig - - -class SearchService(search_pb2_grpc.SearchServiceServicer): - def __init__( - self, - config: SearchConfig, - ): - if (host := config.qdrant.host) == "test": - client = qdrant_client.AsyncQdrantClient(location=":memory:") - else: - client = qdrant_client.AsyncQdrantClient( - host=host, - port=config.qdrant.port, - ) - - vector_store = QdrantVectorStore( - aclient=client, - collection_name=config.qdrant.collection, - ) - embed_model = HuggingFaceEmbedding(model_name=config.embeddings.model) - - self.index = VectorStoreIndex.from_vector_store( - vector_store=vector_store, embed_model=embed_model - ) - - self.query_template = config.query.prompt_template - self.similarity_top_k = config.query.similarity_top_k - - def Search( - self, - request: search_pb2.SearchRequest, - context: grpc.ServicerContext, - ): - prompt = self.query_template.format(keywords=", ".join(request.keywords)) - similarity_top_k = request.similarity_top_k or self.similarity_top_k - - retriever = self.index.as_retriever(similarity_top_k=similarity_top_k) - results: list[NodeWithScore] = retriever.retrieve(prompt) - - return search_pb2.SearchResponse( - results=[ - search_pb2.RetrieveResult( - id=result.node_id, - score=result.score, - content="".join(result.text.split()), - ) - for result in results - ] - ) diff --git a/llm_backend/summarize/__init__.py b/llm_backend/summarize/__init__.py deleted file mode 100644 index 44c8b7b..0000000 --- a/llm_backend/summarize/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from ..protos.summarize_pb2_grpc import ( - add_SummarizeServiceServicer_to_server as add_SummarizeServiceServicer_to_server, -) -from .service import SummarizeService as SummarizeService diff --git a/llm_backend/summarize/service.py b/llm_backend/summarize/service.py deleted file mode 100644 index 2c3ca64..0000000 --- a/llm_backend/summarize/service.py +++ /dev/null @@ -1,65 +0,0 @@ -from collections.abc import Callable, Sequence - -import grpc -from llama_index.core import ChatPromptTemplate -from llama_index.core.llms import ChatMessage, MessageRole -from llama_index.core.prompts import PromptType -from llama_index.core.response_synthesizers import ( - ResponseMode, - get_response_synthesizer, -) -from llama_index.llms.openai import OpenAI - -from llm_backend.protos import summarize_pb2, summarize_pb2_grpc -from llm_backend.protos.summarize_pb2_grpc import add_SummarizeServiceServicer_to_server - -from .config import ContentFormat, SummarizeConfig - -__all__ = [ - "SummarizeService", - "add_SummarizeServiceServicer_to_server", -] - - -type ContentFormatter = Callable[[Sequence[str]], Sequence[str]] - -CONTENT_FORMATTERS: dict[ContentFormat, ContentFormatter] = { - ContentFormat.PLAIN: lambda x: x, - ContentFormat.NUMBERED: lambda x: [ - f"{i}. {line}" for i, line in enumerate(x, start=1) - ], -} - - -class SummarizeService(summarize_pb2_grpc.SummarizeServiceServicer): - def __init__( - self, - config: SummarizeConfig, - ): - llm = OpenAI(model=config.chatgpt.model, api_key=config.chatgpt.api_key) - - self.summarizer = get_response_synthesizer( - llm, response_mode=ResponseMode.TREE_SUMMARIZE, use_async=True - ) - summarizer_prompt = ChatPromptTemplate( - message_templates=[ - ChatMessage( - role=MessageRole.SYSTEM, content=config.query.system_template - ), - ChatMessage(role=MessageRole.USER, content=config.query.user_template), - ], - prompt_type=PromptType.SUMMARY, - ) - self.summarizer.update_prompts({"summary_template": summarizer_prompt}) - - self.query_str = config.query.query_str - self.content_formatter = CONTENT_FORMATTERS[config.query.content_format] - - def Summarize( - self, - request: summarize_pb2.SummarizeRequest, - context: grpc.ServicerContext, - ): - texts = self.content_formatter(request.contents) - summary = str(self.summarizer.get_response(self.query_str, texts)) - return summarize_pb2.SummarizeResponse(summary=summary) diff --git a/llm_backend/utils.py b/llm_backend/utils.py deleted file mode 100644 index 5c7a832..0000000 --- a/llm_backend/utils.py +++ /dev/null @@ -1,8 +0,0 @@ -def contains_placeholder(*placeholders: str): - def validate_template(template: str): - for placeholder in placeholders: - if f"{{{placeholder}}}" not in template: - raise ValueError(f"Template must contain '{{{placeholder}}}'") - return template - - return validate_template diff --git a/pyproject.toml b/pyproject.toml index a2fc243..55fe1f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "grpcio>=1.62.1,<2", "grpcio-tools>=1.62.1,<2", "llama-index~=0.12.16", + "llama-index-core~=0.12.17", "llama-index-vector-stores-qdrant~=0.4.3", "llama-index-embeddings-huggingface~=0.5.1", "llama-index-llms-openai~=0.3.18", @@ -18,7 +19,6 @@ dependencies = [ ] [project.scripts] -serve = "scripts.serve:main" gen-protos = "scripts.gen_protos:generate" [dependency-groups] diff --git a/scripts/serve.py b/scripts/serve.py index 019f279..f37a090 100644 --- a/scripts/serve.py +++ b/scripts/serve.py @@ -1,4 +1,5 @@ import argparse +import asyncio import logging import os import sys @@ -7,34 +8,7 @@ import grpc -from llm_backend import Config, setup_search_service, setup_summarize_service - - -def start_server(server: grpc.Server, config: Config): - try: - server_config = config.server - address = f"{server_config.host}:{server_config.port}" - server.add_insecure_port(address=address) - server.start() - logger.info("Server started on %s", address) - server.wait_for_termination() - except Exception as e: - logger.error("Error occurred while starting server: %s", e) - raise - - -def serve(config: Config): - server = grpc.server( - futures.ThreadPoolExecutor(max_workers=config.server.max_workers) - ) - - setup_search_service(config, server) - logger.info("Added SearchService to server") - - setup_summarize_service(config, server) - logger.info("Added SummarizeService to server") - - start_server(server, config) +from llm_backend import Config, setup_rag_service def parse_args(): @@ -42,7 +16,7 @@ def parse_args(): parser.add_argument( "--config", type=str, - default=os.path.join("configs", "example.toml"), + default=os.path.join("configs", "config.toml"), help="Path to the config file.", ) return parser.parse_args() @@ -54,26 +28,46 @@ def load_config(config_path): config = tomllib.load(config_file) except FileNotFoundError as e: logger.error("Config file not found: %s", e) - raise + sys.exit(1) return Config.model_validate(config) -def main(): - logging.basicConfig( - format="%(asctime)s\t%(levelname)s: %(message)s", - handlers=[ - logging.StreamHandler(sys.stdout), - logging.FileHandler("server.log", "w"), - ], +async def serve(config: Config, logger: logging.Logger): + server = grpc.aio.server( + futures.ThreadPoolExecutor(max_workers=config.server.max_workers) ) - logger.setLevel(logging.INFO) - args = parse_args() - config = load_config(args.config) - serve(config) + setup_rag_service(config, server) + logger.info("RagService setup complete") + + server_config = config.server + address = f"{server_config.host}:{server_config.port}" + server.add_insecure_port(address=address) + logger.info("Server started on %s", address) + + await server.start() + async def server_graceful_shutdown(): + logging.info("Starting graceful shutdown...") + await server.stop(3) + + _cleanup_coroutines.append(server_graceful_shutdown()) + + await server.wait_for_termination() -logger = logging.getLogger("server") if __name__ == "__main__": - main() + logging.basicConfig(format="%(asctime)s\t%(levelname)s: %(message)s") + logger = logging.getLogger("server") + logger.setLevel(logging.INFO) + + args = parse_args() + config = load_config(args.config) + + loop = asyncio.new_event_loop() + _cleanup_coroutines = [] + try: + loop.run_until_complete(serve(config, logger)) + finally: + loop.run_until_complete(*_cleanup_coroutines) + loop.close() diff --git a/uv.lock b/uv.lock index df348cd..7a5f882 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = "==3.12.*" [[package]] @@ -746,6 +747,7 @@ dependencies = [ { name = "grpcio" }, { name = "grpcio-tools" }, { name = "llama-index" }, + { name = "llama-index-core" }, { name = "llama-index-embeddings-huggingface" }, { name = "llama-index-llms-openai" }, { name = "llama-index-vector-stores-qdrant" }, @@ -765,6 +767,7 @@ requires-dist = [ { name = "grpcio", specifier = ">=1.62.1,<2" }, { name = "grpcio-tools", specifier = ">=1.62.1,<2" }, { name = "llama-index", specifier = "~=0.12.16" }, + { name = "llama-index-core", specifier = "~=0.12.17" }, { name = "llama-index-embeddings-huggingface", specifier = "~=0.5.1" }, { name = "llama-index-llms-openai", specifier = "~=0.3.18" }, { name = "llama-index-vector-stores-qdrant", specifier = "~=0.4.3" }, From 047913ba40ec89a984f3c0e60dfc4b98796fd3f2 Mon Sep 17 00:00:00 2001 From: David Chiu Date: Mon, 10 Mar 2025 04:13:13 +0800 Subject: [PATCH 3/4] doc(config): update execute instructions --- .gitignore | 2 -- README.md | 4 ++-- configs/config.toml | 44 ++++++++++++++++++++++++++++++++++++++++++++ configs/example.toml | 43 ------------------------------------------- 4 files changed, 46 insertions(+), 47 deletions(-) create mode 100644 configs/config.toml delete mode 100644 configs/example.toml diff --git a/.gitignore b/.gitignore index 32d45c7..f0d016c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,4 @@ .DS_Store -configs/* -!configs/example.toml llm_backend/protos/ # Byte-compiled / optimized / DLL files diff --git a/README.md b/README.md index c64d094..0435d7e 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ uv run gen-protos ## Usage -Please configure the `configs/config.toml` file (refer to `configs/example.toml` for the options). +Please configure the `configs/config.toml` file. The following environment variables are required (`export` them or place them in a `.env` file): - `OPENAI_API_KEY`: Your ChatGPT API key. @@ -20,7 +20,7 @@ The following environment variables are required (`export` them or place them in - `QDRANT_COLLECTION`: The Qdrant collection name. ```shell -uv run serve --config configs/config.toml +python3 scripts/serve.py --config configs/config.toml ``` ## Features diff --git a/configs/config.toml b/configs/config.toml new file mode 100644 index 0000000..7b2d2ae --- /dev/null +++ b/configs/config.toml @@ -0,0 +1,44 @@ +[server] +host = 'localhost' +port = 50051 +max_workers = 10 + +[service.retrieve] +# Name of embedding model. All available models can be found [here](https://huggingface.co/models?language=zh) +embedding_model = 'intfloat/multilingual-e5-large' + +# The template must contain the `{keywords}` placeholder. +prompt_template = 'Please search for the content related to the following keywords: {keywords}.' +similarity_top_k = 5 + +[service.summarize] +system_template = """ +# Project Mission: My project mission is to extract 5 articles of the same type from the internet each time and provide them to Chat GPT in the same format to generate summaries and digests, making it convenient for the general public to read. +# Input Format: The format during input is as follows: 1.xxx 2.xxx 3.xxx 4.xxx 5.xxx Each news article is numbered with a digit title. There is a blank line between different news articles, but within the same article, there are no line breaks. +# Detailed Project Execution: The detailed execution of the project involves refraining from adding personal opinions. I only generate summaries based on the provided news and refrain from providing responses beyond the scope of the news. +# Audience for My Content: The audience comprises professionals from various fields, as well as students and homemakers. They span a wide range of age groups and have a strong desire for knowledge. However, due to limited time, they cannot personally read a large amount of news information. Therefore, my content typically needs to be transformed into something understandable by the general public, with logical analysis involving specific questions and solutions. + +# Assuming you are now a reader, think step by step about what you think the key points of the news would be, and provide the first version of the summary. Then, based on this summary, pose sub-questions and further modify to provide the final summary. +# Answer in Traditional Chinese, and refrain from providing thoughts and content beyond what you've provided. Endeavor to comprehensively describe the key points of the news. +# Responses should strive to be rigorous and formal, with real evidence when answering questions. +# Answers can be as complete and comprehensive as possible, expanding on details and actual content. +# The "Output Format" is: provide an overarching title that summarizes the news content above, then summarizes the content. +""" + +# The template must contain the `{context_str}` and `{query_str}` placeholders. +user_template = """ +{query_str} +--------------------- +{context_str}""" + +# The content of `{query_str}` placeholder in the user template. +query_str = '假設你是一個摘要抓取者,請將以下---內的文字做一篇文章摘要,用文章敘述的方式呈現,不要用列點的,至少要有500字,要有標題。' + +# The transform function from the request strings to the query strings. +# Must be one of: +# - 'plain': The query string is the same as the request string. +# - 'numbered': Add a number (1., 2., ...) to the beginning of each request string. +content_format = 'plain' + +[service.summarize.llm] +model = 'gpt-4o-mini' diff --git a/configs/example.toml b/configs/example.toml deleted file mode 100644 index 26d7be1..0000000 --- a/configs/example.toml +++ /dev/null @@ -1,43 +0,0 @@ -[server] -host = 'localhost' -port = 50051 -max_workers = 10 - -[service.search.embeddings] -# Name of embedding model. All available models can be found [here](https://huggingface.co/models?language=zh) -model = 'moka-ai/m3e-base' - -[service.search.query] -# The template must contain the `{keywords}` placeholder. -prompt_template = 'Please search for the content related to the following keywords: {keywords}.' -similarity_top_k = 3 - -[service.summarize.chatgpt] -model = 'o3-mini' - -[service.summarize.query] -system_template = """ -You are an expert Q&A system that is trusted around the world. -Always answer the query using the provided context information, and not prior knowledge. -Some rules to follow: -1. Never directly reference the given context in your answer. -2. Avoid statements like 'Based on the context, ...' or The context information ...' or anything along those lines.""" - -# The template must contain the `{context_str}` and `{query_str}` placeholders. -user_template = """ -Context information from multiple sources is below. ---------------------- -{context_str} ---------------------- -Given the information from multiple sources and not prior knowledge, answer the query. -Query: {query_str} -Answer: """ - -# The content of `{query_str}` placeholder in the user template. -query_str = '請用繁體中文總結這幾篇新聞。' - -# The transform function from the request strings to the query strings. -# Must be one of: -# - 'plain': The query string is the same as the request string. -# - 'numbered': Add a number (1., 2., ...) to the beginning of each request string. -content_format = 'plain' From 339abe5872ecfb82aa87254db7527bebd54cadd6 Mon Sep 17 00:00:00 2001 From: David Chiu Date: Mon, 10 Mar 2025 04:17:58 +0800 Subject: [PATCH 4/4] feat(client): add example client implementation for searching with gRPC --- README.md | 6 ++++++ scripts/client.py | 23 +++++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 scripts/client.py diff --git a/README.md b/README.md index 0435d7e..4ebd1a3 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,12 @@ The following environment variables are required (`export` them or place them in python3 scripts/serve.py --config configs/config.toml ``` +You can refer to `scripts/client.py` for an example implementation of a client: + +```shell +python3 scripts/client.py +``` + ## Features Refer to the protobuf files in the `protos/` directory for the features provided by the server. diff --git a/scripts/client.py b/scripts/client.py new file mode 100644 index 0000000..7fa6c58 --- /dev/null +++ b/scripts/client.py @@ -0,0 +1,23 @@ +import grpc + +from llm_backend.protos import rag_pb2, rag_pb2_grpc + + +def run_search(): + with grpc.insecure_channel(ADDRESS) as channel: + client = rag_pb2_grpc.RagServiceStub(channel) + + request = rag_pb2.RagRequest( + keywords=["台灣", "選舉"], + similarity_top_k=5, + ) + + response = client.Rag(request) + + print(response) + + +if __name__ == "__main__": + ADDRESS = "localhost:50051" + print("Running search...") + run_search()