99from typing import Any , Final
1010
1111import sqlalchemy as sa
12- from aiopg .sa .connection import SAConnection
13- from aiopg .sa .result import RowProxy
12+ from common_library .async_tools import maybe_await
1413from sqlalchemy import Column
1514
15+ from ._protocols import DBConnection
1616from .aiopg_errors import UniqueViolation
1717from .models .users import UserRole , UserStatus , users
1818from .models .users_details import users_pre_registration_details
@@ -55,12 +55,12 @@ def generate_alternative_username(username: str) -> str:
5555class UsersRepo :
5656 @staticmethod
5757 async def new_user (
58- conn : SAConnection ,
58+ conn : DBConnection ,
5959 email : str ,
6060 password_hash : str ,
6161 status : UserStatus ,
6262 expires_at : datetime | None ,
63- ) -> RowProxy :
63+ ) -> Any :
6464 data : dict [str , Any ] = {
6565 "name" : _generate_username_from_email (email ),
6666 "email" : email ,
@@ -88,13 +88,15 @@ async def new_user(
8888 users .c .status ,
8989 ).where (users .c .id == user_id )
9090 )
91- row = await result .first ()
91+ row = await maybe_await (result .first ())
92+ from aiopg .sa .result import RowProxy
93+
9294 assert isinstance (row , RowProxy ) # nosec
9395 return row
9496
9597 @staticmethod
9698 async def join_and_update_from_pre_registration_details (
97- conn : SAConnection , new_user_id : int , new_user_email : str
99+ conn : DBConnection , new_user_id : int , new_user_email : str
98100 ) -> None :
99101 """After a user is created, it can be associated with information provided during invitation
100102
@@ -111,6 +113,10 @@ async def join_and_update_from_pre_registration_details(
111113 .values (user_id = new_user_id )
112114 )
113115
116+ from aiopg .sa .result import ResultProxy
117+
118+ assert isinstance (result , ResultProxy ) # nosec
119+
114120 if result .rowcount :
115121 pre_columns = (
116122 users_pre_registration_details .c .pre_first_name ,
@@ -135,7 +141,7 @@ async def join_and_update_from_pre_registration_details(
135141 users_pre_registration_details .c .pre_email == new_user_email
136142 )
137143 )
138- if details := await result .fetchone ():
144+ if details := await maybe_await ( result .fetchone () ):
139145 await conn .execute (
140146 users .update ()
141147 .where (users .c .id == new_user_id )
@@ -169,15 +175,14 @@ def get_billing_details_query(user_id: int):
169175 )
170176
171177 @staticmethod
172- async def get_billing_details (conn : SAConnection , user_id : int ) -> RowProxy | None :
178+ async def get_billing_details (conn : DBConnection , user_id : int ) -> Any | None :
173179 result = await conn .execute (
174180 UsersRepo .get_billing_details_query (user_id = user_id )
175181 )
176- value : RowProxy | None = await result .fetchone ()
177- return value
182+ return await maybe_await (result .fetchone ())
178183
179184 @staticmethod
180- async def get_role (conn : SAConnection , user_id : int ) -> UserRole :
185+ async def get_role (conn : DBConnection , user_id : int ) -> UserRole :
181186 value : UserRole | None = await conn .scalar (
182187 sa .select (users .c .role ).where (users .c .id == user_id )
183188 )
@@ -188,7 +193,7 @@ async def get_role(conn: SAConnection, user_id: int) -> UserRole:
188193 raise UserNotFoundInRepoError
189194
190195 @staticmethod
191- async def get_email (conn : SAConnection , user_id : int ) -> str :
196+ async def get_email (conn : DBConnection , user_id : int ) -> str :
192197 value : str | None = await conn .scalar (
193198 sa .select (users .c .email ).where (users .c .id == user_id )
194199 )
@@ -199,7 +204,7 @@ async def get_email(conn: SAConnection, user_id: int) -> str:
199204 raise UserNotFoundInRepoError
200205
201206 @staticmethod
202- async def get_active_user_email (conn : SAConnection , user_id : int ) -> str :
207+ async def get_active_user_email (conn : DBConnection , user_id : int ) -> str :
203208 value : str | None = await conn .scalar (
204209 sa .select (users .c .email ).where (
205210 (users .c .status == UserStatus .ACTIVE ) & (users .c .id == user_id )
@@ -212,7 +217,7 @@ async def get_active_user_email(conn: SAConnection, user_id: int) -> str:
212217 raise UserNotFoundInRepoError
213218
214219 @staticmethod
215- async def is_email_used (conn : SAConnection , email : str ) -> bool :
220+ async def is_email_used (conn : DBConnection , email : str ) -> bool :
216221 email = email .lower ()
217222
218223 registered = await conn .scalar (
0 commit comments