Skip to content

Commit 90b8887

Browse files
feat(serve): migrate server implementation to use asyncio and grpc.aio for asynchronous handling (#6)
* feat(serve): migrate server implementation to use asyncio and grpc.aio for asynchronous handling
* fix(summarize): await an async function
1 parent d5990b9 commit 90b8887

File tree

5 files changed

+43
-49
lines changed

5 files changed

+43
-49
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ The following environment variables are required (`export` them or place them in
2020
- `QDRANT_COLLECTION`: The Qdrant collection name.
2121

2222
```shell
23-
uv run serve --config configs/config.toml
23+
python3 scripts/serve.py --config configs/config.toml
2424
```
2525

2626
## Features

llm_backend/search/service.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -36,16 +36,16 @@ def __init__(
3636
self.query_template = config.query.prompt_template
3737
self.similarity_top_k = config.query.similarity_top_k
3838

39-
def Search(
39+
async def Search(
4040
self,
4141
request: search_pb2.SearchRequest,
42-
context: grpc.ServicerContext,
42+
context: grpc.aio.ServicerContext,
4343
):
4444
prompt = self.query_template.format(keywords=", ".join(request.keywords))
4545
similarity_top_k = request.similarity_top_k or self.similarity_top_k
4646

4747
retriever = self.index.as_retriever(similarity_top_k=similarity_top_k)
48-
results: list[NodeWithScore] = retriever.retrieve(prompt)
48+
results: list[NodeWithScore] = await retriever.aretrieve(prompt)
4949

5050
return search_pb2.SearchResponse(
5151
results=[

llm_backend/summarize/service.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -55,11 +55,11 @@ def __init__(
5555
self.query_str = config.query.query_str
5656
self.content_formatter = CONTENT_FORMATTERS[config.query.content_format]
5757

58-
def Summarize(
58+
async def Summarize(
5959
self,
6060
request: summarize_pb2.SummarizeRequest,
61-
context: grpc.ServicerContext,
61+
context: grpc.aio.ServicerContext,
6262
):
6363
texts = self.content_formatter(request.contents)
64-
summary = str(self.summarizer.get_response(self.query_str, texts))
64+
summary = str(await self.summarizer.aget_response(self.query_str, texts))
6565
return summarize_pb2.SummarizeResponse(summary=summary)

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@ dependencies = [
1818
]
1919

2020
[project.scripts]
21-
serve = "scripts.serve:main"
2221
gen-protos = "scripts.gen_protos:generate"
2322

2423
[dependency-groups]

scripts/serve.py

Lines changed: 36 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import argparse
2+
import asyncio
23
import logging
34
import os
4-
import sys
55
import tomllib
66
from concurrent import futures
77

@@ -10,33 +10,6 @@
1010
from llm_backend import Config, setup_search_service, setup_summarize_service
1111

1212

13-
def start_server(server: grpc.Server, config: Config):
14-
try:
15-
server_config = config.server
16-
address = f"{server_config.host}:{server_config.port}"
17-
server.add_insecure_port(address=address)
18-
server.start()
19-
logger.info("Server started on %s", address)
20-
server.wait_for_termination()
21-
except Exception as e:
22-
logger.error("Error occurred while starting server: %s", e)
23-
raise
24-
25-
26-
def serve(config: Config):
27-
server = grpc.server(
28-
futures.ThreadPoolExecutor(max_workers=config.server.max_workers)
29-
)
30-
31-
setup_search_service(config, server)
32-
logger.info("Added SearchService to server")
33-
34-
setup_summarize_service(config, server)
35-
logger.info("Added SummarizeService to server")
36-
37-
start_server(server, config)
38-
39-
4013
def parse_args():
4114
parser = argparse.ArgumentParser()
4215
parser.add_argument(
@@ -58,22 +31,44 @@ def load_config(config_path):
5831
return Config.model_validate(config)
5932

6033

61-
def main():
62-
logging.basicConfig(
63-
format="%(asctime)s\t%(levelname)s: %(message)s",
64-
handlers=[
65-
logging.StreamHandler(sys.stdout),
66-
logging.FileHandler("server.log", "w"),
67-
],
34+
async def serve(config: Config, logger: logging.Logger):
35+
server = grpc.aio.server(
36+
futures.ThreadPoolExecutor(max_workers=config.server.max_workers)
6837
)
69-
logger.setLevel(logging.INFO)
38+
setup_search_service(config, server)
39+
logger.info("Added SearchService to server")
7040

71-
args = parse_args()
72-
config = load_config(args.config)
73-
serve(config)
41+
setup_summarize_service(config, server)
42+
logger.info("Added SummarizeService to server")
43+
44+
server_config = config.server
45+
address = f"{server_config.host}:{server_config.port}"
46+
server.add_insecure_port(address=address)
47+
logger.info("Server started on %s", address)
48+
49+
await server.start()
7450

51+
async def server_graceful_shutdown():
52+
logging.info("Starting graceful shutdown...")
53+
await server.stop(3)
54+
55+
_cleanup_coroutines.append(server_graceful_shutdown())
56+
57+
await server.wait_for_termination()
7558

76-
logger = logging.getLogger("server")
7759

7860
if __name__ == "__main__":
79-
main()
61+
logging.basicConfig(format="%(asctime)s\t%(levelname)s: %(message)s")
62+
logger = logging.getLogger("server")
63+
logger.setLevel(logging.INFO)
64+
65+
args = parse_args()
66+
config = load_config(args.config)
67+
68+
loop = asyncio.new_event_loop()
69+
_cleanup_coroutines = []
70+
try:
71+
loop.run_until_complete(serve(config, logger))
72+
finally:
73+
loop.run_until_complete(*_cleanup_coroutines)
74+
loop.close()

0 commit comments

Comments
 (0)