1- import random
1+ import time
22from collections .abc import AsyncGenerator
33from collections .abc import Generator
44from dataclasses import dataclass
77
88import pytest
99from asgi_lifespan import LifespanManager
10+ from devtools import debug
1011from fastapi import FastAPI
1112from httpx import ASGITransport
1213from httpx import AsyncClient
13- from sqlalchemy .exc import IntegrityError
1414from sqlalchemy .ext .asyncio import AsyncSession
1515from sqlmodel import select
1616
@@ -206,7 +206,7 @@ async def client(
206206
207207
208208@pytest .fixture
209- async def registered_client (
209+ async def registered_client ( # FIXME maybe remove?
210210 app : FastAPI , register_routers , db
211211) -> AsyncGenerator [AsyncClient , Any ]:
212212@@ -273,50 +273,66 @@ async def default_user_group(db) -> UserGroup:
273273
274274
275275@pytest .fixture
276- async def MockCurrentUser (app , db , default_user_group ):
276+ async def MockCurrentUser (app : FastAPI , db , default_user_group ):
277277 from fractal_server .app .routes .auth import (
278278 current_user_act_ver_prof ,
279279 current_user_act ,
280280 current_user_act_ver ,
281+ current_superuser_act ,
281282 )
282- from fractal_server .app .routes .auth import current_superuser_act
283283
284- def _random_email ():
285- return f"{ random . randint ( 0 , 100000000 )} @example.org"
284+ def _new_mail ():
285+ return f"{ time . perf_counter_ns ( )} @example.org"
286286
287287 @dataclass
288288 class _MockCurrentUser :
289289 """
290290 Context managed user override
291291 """
292292
293- user_kwargs : dict [str , Any ] | None = None
294- email : str | None = field (default_factory = _random_email )
295- previous_dependencies : dict = field (default_factory = dict )
293+ user_kwargs : dict [str , Any ] = field (default_factory = dict )
294+ email : str | None = field (default_factory = _new_mail )
295+ previous_deps : dict = field (default_factory = dict )
296+ debug : bool = False
296297
297298 async def __aenter__ (self ):
298- if self .user_kwargs is not None and "id" in self .user_kwargs :
299+ user_id = self .user_kwargs .get ("id" , None )
300+ if user_id is not None :
301+ # (1) Look for existing user, by ID
299302 db_user = await db .get (
300- UserOAuth , self .user_kwargs ["id" ], populate_existing = True
303+ UserOAuth ,
304+ user_id ,
305+ populate_existing = True ,
301306 )
307+ if self .debug :
308+ debug ("FOUND USER" , db_user )
302309 if db_user is None :
303310 raise RuntimeError (
304- f"User with id { self . user_kwargs [ 'id' ] } doesn't exist"
311+ f"[MockCurrentUser] User with { user_id = } doesn't exist"
305312 )
313+ for k , v in self .user_kwargs .items ():
314+ if not getattr (db_user , k ) == v :
315+ raise RuntimeError (
316+ f"[MockCurrentUser] User with { user_id = } has "
317+ f"{ k } ={ v } ."
318+ )
306319 self .user = db_user
307- # Removing objects from test db session, so that we can operate
308- # on them from other sessions
309- db .expunge (self .user )
310320 else :
311- if (
312- self .user_kwargs is not None
313- and "profile_id" not in self .user_kwargs .keys ()
314- ):
321+ # (2) Create new user
322+ default_user_kwargs = dict (
323+ email = self .email ,
324+ hashed_password = "fake_hashed_password" ,
325+ project_dir = PROJECT_DIR_PLACEHOLDER ,
326+ is_verified = True ,
327+ )
328+
329+ # (2/a) Handle resource and profile
330+ if "profile_id" not in self .user_kwargs .keys ():
315331 res = await db .execute (select (Profile ))
316- profile = res .scalars ().one_or_none ()
332+ profile = res .scalars ().first ()
317333 if profile is None :
318334 resource = Resource (
319- name = "local resource 1 " ,
335+ name = "Local resource" ,
320336 type = ResourceType .LOCAL ,
321337 jobs_local_dir = "/jobs_local_dir" ,
322338 tasks_local_dir = "/tasks_local_dir" ,
@@ -332,8 +348,8 @@ async def __aenter__(self):
332348 await db .commit ()
333349 await db .refresh (resource )
334350 db .expunge (resource )
335-
336351 profile = Profile (
352+ username = "test01" ,
337353 resource_id = resource .id ,
338354 name = "local_resource_profile_objects" ,
339355 resource_type = ResourceType .LOCAL ,
@@ -342,76 +358,85 @@ async def __aenter__(self):
342358 await db .commit ()
343359 await db .refresh (profile )
344360 db .expunge (profile )
345- profile_id = profile .id
361+ default_user_kwargs [ " profile_id" ] = profile .id
346362
347363 # Create new user
348- user_attributes = dict (
349- email = self .email ,
350- hashed_password = "fake_hashed_password" ,
351- project_dir = "/fake" ,
352- profile_id = profile_id ,
353- )
354- if self .user_kwargs is not None :
355- user_attributes .update (self .user_kwargs )
356- self .user = UserOAuth (** user_attributes )
357-
358- try :
359- db .add (self .user )
360- await db .commit ()
361- except IntegrityError :
362- # Safety net, in case of non-unique email addresses
363- await db .rollback ()
364- self .user .email = _random_email ()
365- db .add (self .user )
366- await db .commit ()
364+ default_user_kwargs .update (self .user_kwargs )
365+ self .user = UserOAuth (** default_user_kwargs )
366+
367+ db .add (self .user )
368+ await db .commit ()
367369 await db .refresh (self .user )
368370
371+ if self .debug :
372+ debug ("CREATED USER" , self .user )
373+
369374 db .add (
370375 LinkUserGroup (
371- user_id = self .user .id , group_id = default_user_group .id
376+ user_id = self .user .id ,
377+ group_id = default_user_group .id ,
372378 )
373379 )
380+ await db .commit ()
381+ if self .debug :
382+ debug (
383+ f"Created link between user_id={ self .user .id } and "
384+ f"group_id={ default_user_group .id } ."
385+ )
374386
375- # Removing objects from test db session, so that we can operate
376- # on them from other sessions
377- db .expunge (self .user )
387+ # Removing objects from test db session, so that we can operate
388+ # on them from other sessions
389+ db .expunge (self .user )
378390
379391 # Find out which dependencies should be overridden, and store their
380392 # pre-override value
381393 if self .user .is_active :
382- self . previous_dependencies [
383- current_user_act
384- ] = app . dependency_overrides . get ( current_user_act , None )
385- if self . user . is_active and self . user . is_superuser :
386- self .previous_dependencies [
387- current_superuser_act
388- ] = app . dependency_overrides . get ( current_superuser_act , None )
394+ dep = current_user_act
395+ self . previous_deps [ dep ] = app . dependency_overrides . get (
396+ dep , None
397+ )
398+ if self .debug :
399+ debug ( f"Override { current_user_act } ." )
400+
389401 if self .user .is_active and self .user .is_verified :
390- self .previous_dependencies [
391- current_user_act_ver
392- ] = app .dependency_overrides .get (current_user_act_ver , None )
402+ dep = current_user_act_ver
403+ self .previous_deps [dep ] = app .dependency_overrides .get (
404+ dep , None
405+ )
406+ if self .debug :
407+ debug (f"Override { current_user_act_ver } ." )
408+
393409 if (
394410 self .user .is_active
395411 and self .user .is_verified
396412 and self .user .profile_id is not None
397413 ):
398- self .previous_dependencies [
399- current_user_act_ver_prof
400- ] = app .dependency_overrides .get (
401- current_user_act_ver_prof , None
414+ dep = current_user_act_ver_prof
415+ if self .debug :
416+ debug (f"Override { current_user_act_ver_prof } ." )
417+ self .previous_deps [dep ] = app .dependency_overrides .get (
418+ dep , None
419+ )
420+
421+ if self .user .is_active and self .user .is_superuser :
422+ dep = current_superuser_act
423+ if self .debug :
424+ debug (f"Override { current_superuser_act } ." )
425+ self .previous_deps [dep ] = app .dependency_overrides .get (
426+ dep , None
402427 )
403428
404429 # Override dependencies in the FastAPI app
405- for dep in self .previous_dependencies .keys ():
406- app .dependency_overrides [dep ] = lambda : self .user
430+ for _dep in self .previous_deps .keys ():
431+ app .dependency_overrides [_dep ] = lambda : self .user
407432
408433 return self .user
409434
410435 async def __aexit__ (self , * args , ** kwargs ):
411436 # Reset overridden dependencies to the original ones
412- for dep , previous_dep in self .previous_dependencies .items ():
437+ for _dep , previous_dep in self .previous_deps .items ():
413438 if previous_dep is not None :
414- app .dependency_overrides [dep ] = previous_dep
439+ app .dependency_overrides [_dep ] = previous_dep
415440
416441 return _MockCurrentUser
417442
0 commit comments