77
88import pytest
99import sqlalchemy as sa
10- from aiopg .sa .connection import SAConnection
11- from aiopg .sa .result import ResultProxy , RowProxy
1210from faker import Faker
1311from pytest_simcore .helpers .faker_factories import random_user
14- from simcore_postgres_database .aiopg_errors import (
15- InvalidTextRepresentation ,
16- UniqueViolation ,
17- )
1812from simcore_postgres_database .models .users import UserRole , UserStatus , users
13+ from simcore_postgres_database .utils_repos import (
14+ transaction_context ,
15+ )
1916from simcore_postgres_database .utils_users import (
2017 UsersRepo ,
2118 _generate_username_from_email ,
2219 generate_alternative_username ,
2320)
21+ from sqlalchemy .exc import DataError , IntegrityError
22+ from sqlalchemy .ext .asyncio import AsyncEngine
2423from sqlalchemy .sql import func
2524
2625
2726@pytest .fixture
28- async def clean_users_db_table (connection : SAConnection ):
27+ async def clean_users_db_table (asyncpg_engine : AsyncEngine ):
2928 yield
30- await connection .execute (users .delete ())
29+ async with transaction_context (asyncpg_engine ) as connection :
30+ await connection .execute (users .delete ())
3131
3232
3333async def test_user_status_as_pending (
34- connection : SAConnection , faker : Faker , clean_users_db_table : None
34+ asyncpg_engine : AsyncEngine , faker : Faker , clean_users_db_table : None
3535):
3636 """Checks a bug where the expression
3737
@@ -51,10 +51,13 @@ async def test_user_status_as_pending(
5151 # tests that the database never stores the word "PENDING"
5252 data = random_user (faker , status = "PENDING" )
5353 assert data ["status" ] == "PENDING"
54- with pytest .raises (InvalidTextRepresentation ) as err_info :
55- await connection .execute (users .insert ().values (data ))
54+ async with transaction_context (asyncpg_engine ) as connection :
55+ with pytest .raises (DataError ) as err_info :
56+ await connection .execute (users .insert ().values (data ))
5657
57- assert 'invalid input value for enum userstatus: "PENDING"' in f"{ err_info .value } "
58+ assert (
59+ 'invalid input value for enum userstatus: "PENDING"' in f"{ err_info .value } "
60+ )
5861
5962
6063@pytest .mark .parametrize (
@@ -66,27 +69,30 @@ async def test_user_status_as_pending(
6669)
6770async def test_user_status_inserted_as_enum_or_int (
6871 status_value : UserStatus | str ,
69- connection : SAConnection ,
72+ asyncpg_engine : AsyncEngine ,
7073 faker : Faker ,
7174 clean_users_db_table : None ,
7275):
7376 # insert as `status_value`
7477 data = random_user (faker , status = status_value )
7578 assert data ["status" ] == status_value
76- user_id = await connection .scalar (users .insert ().values (data ).returning (users .c .id ))
7779
78- # get as UserStatus.CONFIRMATION_PENDING
79- user = await (
80- await connection .execute (users .select ().where (users .c .id == user_id ))
81- ).first ()
82- assert user
80+ async with transaction_context (asyncpg_engine ) as connection :
81+ user_id = await connection .scalar (
82+ users .insert ().values (data ).returning (users .c .id )
83+ )
84+
85+ # get as UserStatus.CONFIRMATION_PENDING
86+ result = await connection .execute (users .select ().where (users .c .id == user_id ))
87+ user = result .one_or_none ()
88+ assert user
8389
84- assert UserStatus (user .status ) == UserStatus .CONFIRMATION_PENDING
85- assert user .status == UserStatus .CONFIRMATION_PENDING
90+ assert UserStatus (user .status ) == UserStatus .CONFIRMATION_PENDING
91+ assert user .status == UserStatus .CONFIRMATION_PENDING
8692
8793
8894async def test_unique_username (
89- connection : SAConnection , faker : Faker , clean_users_db_table : None
95+ asyncpg_engine : AsyncEngine , faker : Faker , clean_users_db_table : None
9096):
9197 data = random_user (
9298 faker ,
@@ -96,95 +102,101 @@ async def test_unique_username(
96102 first_name = "Pedro" ,
97103 last_name = "Crespo Valero" ,
98104 )
99- user_id = await connection .scalar (users .insert ().values (data ).returning (users .c .id ))
100- user = await (
101- await connection .execute (users .select ().where (users .c .id == user_id ))
102- ).first ()
103- assert user
104-
105- assert user .id == user_id
106- assert user .name == "pcrespov"
107-
108- # same name fails
109- data ["email" ] = faker .email ()
110- with pytest .raises (UniqueViolation ):
111- await connection .scalar (users .insert ().values (data ).returning (users .c .id ))
105+ async with transaction_context (asyncpg_engine ) as connection :
106+ user_id = await connection .scalar (
107+ users .insert ().values (data ).returning (users .c .id )
108+ )
109+ result = await connection .execute (users .select ().where (users .c .id == user_id ))
110+ user = result .one_or_none ()
111+ assert user
112+
113+ assert user .id == user_id
114+ assert user .name == "pcrespov"
112115
113- # generate new name
114- data ["name " ] = _generate_username_from_email ( user .email )
115- data [ "email" ] = faker . email ()
116- await connection .scalar (users .insert ().values (data ).returning (users .c .id ))
116+ # same name fails
117+ data ["email " ] = faker .email ( )
118+ with pytest . raises ( IntegrityError ):
119+ await connection .scalar (users .insert ().values (data ).returning (users .c .id ))
117120
118- # and another one
119- data ["name" ] = generate_alternative_username (data ["name" ])
120- data ["email" ] = faker .email ()
121- await connection .scalar (users .insert ().values (data ).returning (users .c .id ))
121+ # generate new name
122+ data ["name" ] = _generate_username_from_email (user .email )
123+ data ["email" ] = faker .email ()
124+ await connection .scalar (users .insert ().values (data ).returning (users .c .id ))
125+
126+ # and another one
127+ data ["name" ] = generate_alternative_username (data ["name" ])
128+ data ["email" ] = faker .email ()
129+ await connection .scalar (users .insert ().values (data ).returning (users .c .id ))
122130
123131
124132async def test_new_user (
125- connection : SAConnection , faker : Faker , clean_users_db_table : None
133+ asyncpg_engine : AsyncEngine , faker : Faker , clean_users_db_table : None
126134):
127135 data = {
128136 "email" : faker .email (),
129137 "password_hash" : "foo" ,
130138 "status" : UserStatus .ACTIVE ,
131139 "expires_at" : datetime .utcnow (),
132140 }
133- new_user = await UsersRepo .new_user (connection , ** data )
134-
135- assert new_user .email == data ["email" ]
136- assert new_user .status == data ["status" ]
137- assert new_user .role == UserRole .USER
138-
139- other_email = f"{ new_user .name } @other-domain.com"
140- assert _generate_username_from_email (other_email ) == new_user .name
141- other_data = {** data , "email" : other_email }
142-
143- other_user = await UsersRepo .new_user (connection , ** other_data )
144- assert other_user .email != new_user .email
145- assert other_user .name != new_user .name
146-
147- assert await UsersRepo .get_email (connection , other_user .id ) == other_user .email
148- assert await UsersRepo .get_role (connection , other_user .id ) == other_user .role
149- assert (
150- await UsersRepo .get_active_user_email (connection , other_user .id )
151- == other_user .email
152- )
141+ async with transaction_context (asyncpg_engine ) as connection :
142+ new_user = await UsersRepo .new_user (connection , ** data )
143+
144+ assert new_user .email == data ["email" ]
145+ assert new_user .status == data ["status" ]
146+ assert new_user .role == UserRole .USER
147+
148+ other_email = f"{ new_user .name } @other-domain.com"
149+ assert _generate_username_from_email (other_email ) == new_user .name
150+ other_data = {** data , "email" : other_email }
151+
152+ other_user = await UsersRepo .new_user (connection , ** other_data )
153+ assert other_user .email != new_user .email
154+ assert other_user .name != new_user .name
155+
156+ assert await UsersRepo .get_email (connection , other_user .id ) == other_user .email
157+ assert await UsersRepo .get_role (connection , other_user .id ) == other_user .role
158+ assert (
159+ await UsersRepo .get_active_user_email (connection , other_user .id )
160+ == other_user .email
161+ )
153162
154163
155- async def test_trial_accounts (connection : SAConnection , clean_users_db_table : None ):
164+ async def test_trial_accounts (asyncpg_engine : AsyncEngine , clean_users_db_table : None ):
156165 EXPIRATION_INTERVAL = timedelta (minutes = 5 )
157166
158167 # creates trial user
159168 client_now = datetime .utcnow ()
160- user_id : int | None = await connection .scalar (
161- users .insert ()
162- .values (
163- ** random_user (
164- status = UserStatus .ACTIVE ,
165- # Using some magic from sqlachemy ...
166- expires_at = func .now () + EXPIRATION_INTERVAL ,
169+ async with transaction_context (asyncpg_engine ) as connection :
170+ user_id : int | None = await connection .scalar (
171+ users .insert ()
172+ .values (
173+ ** random_user (
174+ status = UserStatus .ACTIVE ,
175+ # Using some magic from sqlachemy ...
176+ expires_at = func .now () + EXPIRATION_INTERVAL ,
177+ )
167178 )
179+ .returning (users .c .id )
168180 )
169- .returning (users .c .id )
170- )
171- assert user_id
181+ assert user_id
172182
173- # check expiration date
174- result : ResultProxy = await connection .execute (
175- sa .select (users .c .status , users .c .created_at , users .c .expires_at ).where (
176- users .c .id == user_id
183+ # check expiration date
184+ result = await connection .execute (
185+ sa .select (users .c .status , users .c .created_at , users .c .expires_at ).where (
186+ users .c .id == user_id
187+ )
188+ )
189+ row = result .one_or_none ()
190+ assert row
191+ assert row .created_at - client_now < timedelta (
192+ minutes = 1
193+ ), "Difference between server and client now should not differ much"
194+ assert row .expires_at - row .created_at == EXPIRATION_INTERVAL
195+ assert row .status == UserStatus .ACTIVE
196+
197+ # sets user as expired
198+ await connection .execute (
199+ users .update ()
200+ .values (status = UserStatus .EXPIRED )
201+ .where (users .c .id == user_id )
177202 )
178- )
179- row : RowProxy | None = await result .first ()
180- assert row
181- assert row .created_at - client_now < timedelta (
182- minutes = 1
183- ), "Difference between server and client now should not differ much"
184- assert row .expires_at - row .created_at == EXPIRATION_INTERVAL
185- assert row .status == UserStatus .ACTIVE
186-
187- # sets user as expired
188- await connection .execute (
189- users .update ().values (status = UserStatus .EXPIRED ).where (users .c .id == user_id )
190- )
0 commit comments