11import os
22import asyncio
3- import aioredis
3+ from redis import asyncio as aioredis
44import uvloop
55import socket
66import uuid
77import contextvars
8+ from contextlib import asynccontextmanager
89from fastapi import FastAPI , Depends , Request
910from starlette .staticfiles import StaticFiles
1011from starlette .templating import Jinja2Templates
1112from starlette .middleware .base import BaseHTTPMiddleware
1213from starlette .websockets import WebSocket , WebSocketDisconnect
1314
1415from websockets .exceptions import ConnectionClosedError , ConnectionClosedOK
15- from aioredis . errors import ConnectionClosedError as ServerConnectionClosedError
16+ from redis . exceptions import ConnectionError as ServerConnectionClosedError
1617
1718REDIS_HOST = 'localhost'
1819REDIS_PORT = 6379
@@ -43,7 +44,32 @@ async def dispatch(self, request, call_next):
4344
4445
4546asyncio .set_event_loop_policy (uvloop .EventLoopPolicy ())
46- app = FastAPI ()
47+
48+ @asynccontextmanager
49+ async def lifespan (app : FastAPI ):
50+ # Startup
51+ try :
52+ redis_url = f"redis://{ REDIS_HOST } :{ REDIS_PORT } "
53+ pool = await aioredis .from_url (
54+ redis_url ,
55+ encoding = 'utf-8' ,
56+ decode_responses = True ,
57+ max_connections = 20
58+ )
59+ cvar_redis .set (pool )
60+ print ("Connected to Redis on " , REDIS_HOST , REDIS_PORT )
61+ except ConnectionRefusedError as e :
62+ print ('cannot connect to redis on:' , REDIS_HOST , REDIS_PORT )
63+
64+ yield
65+
66+ # Shutdown
67+ redis = cvar_redis .get ()
68+ if redis :
69+ await redis .aclose ()
70+ print ("closed connection Redis on " , REDIS_HOST , REDIS_PORT )
71+
72+ app = FastAPI (lifespan = lifespan )
4773app .add_middleware (CustomHeaderMiddleware )
4874templates = Jinja2Templates (directory = "templates" )
4975
@@ -73,8 +99,12 @@ def get_local_ip():
7399
74100async def get_redis_pool ():
75101 try :
76- pool = await aioredis .create_redis_pool (
77- (REDIS_HOST , REDIS_PORT ), encoding = 'utf-8' )
102+ redis_url = f"redis://{ REDIS_HOST } :{ REDIS_PORT } "
103+ pool = await aioredis .from_url (
104+ redis_url ,
105+ encoding = 'utf-8' ,
106+ decode_responses = True
107+ )
78108 return pool
79109 except ConnectionRefusedError as e :
80110 print ('cannot connect to redis on:' , REDIS_HOST , REDIS_PORT )
@@ -97,21 +127,20 @@ async def ws_send_moderator(websocket: WebSocket, chat_info: dict):
97127 """
98128 pool = await get_redis_pool ()
99129 streams = chat_info ['room' ].split (',' )
100- latest_ids = [ '$' for i in streams ]
130+ latest_ids = { stream : '$' for stream in streams }
101131 ws_connected = True
102132 print (streams , latest_ids )
103133 while pool and ws_connected :
104134 try :
105135 events = await pool .xread (
106- streams = streams ,
136+ streams = latest_ids ,
107137 count = XREAD_COUNT ,
108- timeout = XREAD_TIMEOUT ,
109- latest_ids = latest_ids
138+ block = XREAD_TIMEOUT if XREAD_TIMEOUT > 0 else None
110139 )
111- for _ , e_id , e in events :
112- e [ ' e_id' ] = e_id
113- await websocket . send_json ( e )
114- #latest_ids = [e_id]
140+ for stream , messages in events :
141+ for e_id , e in messages :
142+ e [ 'e_id' ] = e_id
143+ await websocket . send_json ( e )
115144 except ConnectionClosedError :
116145 ws_connected = False
117146
@@ -130,18 +159,19 @@ async def ws_send(websocket: WebSocket, chat_info: dict):
130159 :type chat_info:
131160 """
132161 pool = await get_redis_pool ()
133- latest_ids = ['$' ]
162+ stream_key = cvar_tenant .get () + ":stream"
163+ latest_ids = {stream_key : '$' }
134164 ws_connected = True
135165 first_run = True
136166 while pool and ws_connected :
137167 try :
138168 if first_run :
139169 # fetch some previous chat history
140170 events = await pool .xrevrange (
141- stream = cvar_tenant . get () + ":stream" ,
171+ name = stream_key ,
142172 count = NUM_PREVIOUS ,
143- start = '+ ' ,
144- stop = '- '
173+ min = '- ' ,
174+ max = '+ '
145175 )
146176 first_run = False
147177 events .reverse ()
@@ -150,15 +180,15 @@ async def ws_send(websocket: WebSocket, chat_info: dict):
150180 await websocket .send_json (e )
151181 else :
152182 events = await pool .xread (
153- streams = [ cvar_tenant . get () + ":stream" ] ,
183+ streams = latest_ids ,
154184 count = XREAD_COUNT ,
155- timeout = XREAD_TIMEOUT ,
156- latest_ids = latest_ids
185+ block = XREAD_TIMEOUT if XREAD_TIMEOUT > 0 else None
157186 )
158- for _ , e_id , e in events :
159- e ['e_id' ] = e_id
160- await websocket .send_json (e )
161- latest_ids = [e_id ]
187+ for stream , messages in events :
188+ for e_id , e in messages :
189+ e ['e_id' ] = e_id
190+ await websocket .send_json (e )
191+ latest_ids = {stream_key : e_id }
162192 #print('################contextvar ', cvar_tenant.get())
163193 except ConnectionClosedError :
164194 ws_connected = False
@@ -169,7 +199,7 @@ async def ws_send(websocket: WebSocket, chat_info: dict):
169199 except ServerConnectionClosedError :
170200 print ('redis server connection closed' )
171201 return
172- pool .close ()
202+ await pool .aclose ()
173203
174204
175205async def ws_recieve (websocket : WebSocket , chat_info : dict ):
@@ -205,10 +235,12 @@ async def ws_recieve(websocket: WebSocket, chat_info: dict):
205235 'type' : 'comment' ,
206236 'room' : chat_info ['room' ]
207237 }
208- await pool .xadd (stream = cvar_tenant .get () + ":stream" ,
209- fields = fields ,
210- message_id = b'*' ,
211- max_len = STREAM_MAX_LEN )
238+ await pool .xadd (
239+ name = cvar_tenant .get () + ":stream" ,
240+ fields = fields ,
241+ id = '*' ,
242+ maxlen = STREAM_MAX_LEN
243+ )
212244 #print('################contextvar ', cvar_tenant.get())
213245 except WebSocketDisconnect :
214246 await remove_room_user (chat_info , pool )
@@ -223,7 +255,7 @@ async def ws_recieve(websocket: WebSocket, chat_info: dict):
223255 print ('redis server connection closed' )
224256 return
225257
226- pool .close ()
258+ await pool .aclose ()
227259
228260
229261async def add_room_user (chat_info : dict , pool ):
@@ -259,10 +291,12 @@ async def announce(pool, chat_info: dict, action: str):
259291 }
260292 #print(fields)
261293
262- await pool .xadd (stream = cvar_tenant .get () + ":stream" ,
263- fields = fields ,
264- message_id = b'*' ,
265- max_len = STREAM_MAX_LEN )
294+ await pool .xadd (
295+ name = cvar_tenant .get () + ":stream" ,
296+ fields = fields ,
297+ id = '*' ,
298+ maxlen = STREAM_MAX_LEN
299+ )
266300
267301
268302async def chat_info_vars (username : str = None , room : str = None ):
@@ -355,30 +389,10 @@ async def verify_user_for_room(chat_info):
355389 # whitelist rooms
356390 if not chat_info ['room' ] in ALLOWED_ROOMS :
357391 verified = False
358- pool .close ()
392+ await pool .aclose ()
359393 return verified
360394
361395
362- @app .on_event ("startup" )
363- async def handle_startup ():
364- try :
365- pool = await aioredis .create_redis_pool (
366- (REDIS_HOST , REDIS_PORT ), encoding = 'utf-8' , maxsize = 20 )
367- cvar_redis .set (pool )
368- print ("Connected to Redis on " , REDIS_HOST , REDIS_PORT )
369- except ConnectionRefusedError as e :
370- print ('cannot connect to redis on:' , REDIS_HOST , REDIS_PORT )
371- return
372-
373-
374- @app .on_event ("shutdown" )
375- async def handle_shutdown ():
376- redis = cvar_redis .get ()
377- redis .close ()
378- await redis .wait_closed ()
379- print ("closed connection Redis on " , REDIS_HOST , REDIS_PORT )
380-
381-
382396if __name__ == "__main__" :
383397 import uvicorn
384398 print (dir (app ))
0 commit comments