11import argparse
2+ import asyncio
23import logging
34import os
4- import sys
55import tomllib
66from concurrent import futures
77
1010from llm_backend import Config , setup_search_service , setup_summarize_service
1111
1212
13- def start_server (server : grpc .Server , config : Config ):
14- try :
15- server_config = config .server
16- address = f"{ server_config .host } :{ server_config .port } "
17- server .add_insecure_port (address = address )
18- server .start ()
19- logger .info ("Server started on %s" , address )
20- server .wait_for_termination ()
21- except Exception as e :
22- logger .error ("Error occurred while starting server: %s" , e )
23- raise
24-
25-
26- def serve (config : Config ):
27- server = grpc .server (
28- futures .ThreadPoolExecutor (max_workers = config .server .max_workers )
29- )
30-
31- setup_search_service (config , server )
32- logger .info ("Added SearchService to server" )
33-
34- setup_summarize_service (config , server )
35- logger .info ("Added SummarizeService to server" )
36-
37- start_server (server , config )
38-
39-
4013def parse_args ():
4114 parser = argparse .ArgumentParser ()
4215 parser .add_argument (
@@ -58,22 +31,44 @@ def load_config(config_path):
5831 return Config .model_validate (config )
5932
6033
61- def main ():
62- logging .basicConfig (
63- format = "%(asctime)s\t %(levelname)s: %(message)s" ,
64- handlers = [
65- logging .StreamHandler (sys .stdout ),
66- logging .FileHandler ("server.log" , "w" ),
67- ],
34+ async def serve (config : Config , logger : logging .Logger ):
35+ server = grpc .aio .server (
36+ futures .ThreadPoolExecutor (max_workers = config .server .max_workers )
6837 )
69- logger .setLevel (logging .INFO )
38+ setup_search_service (config , server )
39+ logger .info ("Added SearchService to server" )
7040
71- args = parse_args ()
72- config = load_config (args .config )
73- serve (config )
41+ setup_summarize_service (config , server )
42+ logger .info ("Added SummarizeService to server" )
43+
44+ server_config = config .server
45+ address = f"{ server_config .host } :{ server_config .port } "
46+ server .add_insecure_port (address = address )
47+ logger .info ("Server started on %s" , address )
48+
49+ await server .start ()
7450
51+ async def server_graceful_shutdown ():
52+ logging .info ("Starting graceful shutdown..." )
53+ await server .stop (3 )
54+
55+ _cleanup_coroutines .append (server_graceful_shutdown ())
56+
57+ await server .wait_for_termination ()
7558
76- logger = logging .getLogger ("server" )
7759
7860if __name__ == "__main__" :
79- main ()
61+ logging .basicConfig (format = "%(asctime)s\t %(levelname)s: %(message)s" )
62+ logger = logging .getLogger ("server" )
63+ logger .setLevel (logging .INFO )
64+
65+ args = parse_args ()
66+ config = load_config (args .config )
67+
68+ loop = asyncio .new_event_loop ()
69+ _cleanup_coroutines = []
70+ try :
71+ loop .run_until_complete (serve (config , logger ))
72+ finally :
73+ loop .run_until_complete (* _cleanup_coroutines )
74+ loop .close ()
0 commit comments