Skip to content

Commit 7556567

Browse files
committed
improve MockCurrentUser
1 parent 945a702 commit 7556567

File tree

1 file changed

+90
-65
lines changed

1 file changed

+90
-65
lines changed

tests/fixtures_server.py

Lines changed: 90 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import random
1+
import time
22
from collections.abc import AsyncGenerator
33
from collections.abc import Generator
44
from dataclasses import dataclass
@@ -7,10 +7,10 @@
77

88
import pytest
99
from asgi_lifespan import LifespanManager
10+
from devtools import debug
1011
from fastapi import FastAPI
1112
from httpx import ASGITransport
1213
from httpx import AsyncClient
13-
from sqlalchemy.exc import IntegrityError
1414
from sqlalchemy.ext.asyncio import AsyncSession
1515
from sqlmodel import select
1616

@@ -206,7 +206,7 @@ async def client(
206206

207207

208208
@pytest.fixture
209-
async def registered_client(
209+
async def registered_client( # FIXME maybe remove?
210210
app: FastAPI, register_routers, db
211211
) -> AsyncGenerator[AsyncClient, Any]:
212212
@@ -273,50 +273,66 @@ async def default_user_group(db) -> UserGroup:
273273

274274

275275
@pytest.fixture
276-
async def MockCurrentUser(app, db, default_user_group):
276+
async def MockCurrentUser(app: FastAPI, db, default_user_group):
277277
from fractal_server.app.routes.auth import (
278278
current_user_act_ver_prof,
279279
current_user_act,
280280
current_user_act_ver,
281+
current_superuser_act,
281282
)
282-
from fractal_server.app.routes.auth import current_superuser_act
283283

284-
def _random_email():
285-
return f"{random.randint(0, 100000000)}@example.org"
284+
def _new_mail():
285+
return f"{time.perf_counter_ns()}@example.org"
286286

287287
@dataclass
288288
class _MockCurrentUser:
289289
"""
290290
Context managed user override
291291
"""
292292

293-
user_kwargs: dict[str, Any] | None = None
294-
email: str | None = field(default_factory=_random_email)
295-
previous_dependencies: dict = field(default_factory=dict)
293+
user_kwargs: dict[str, Any] = field(default_factory=dict)
294+
email: str | None = field(default_factory=_new_mail)
295+
previous_deps: dict = field(default_factory=dict)
296+
debug: bool = False
296297

297298
async def __aenter__(self):
298-
if self.user_kwargs is not None and "id" in self.user_kwargs:
299+
user_id = self.user_kwargs.get("id", None)
300+
if user_id is not None:
301+
# (1) Look for existing user, by ID
299302
db_user = await db.get(
300-
UserOAuth, self.user_kwargs["id"], populate_existing=True
303+
UserOAuth,
304+
user_id,
305+
populate_existing=True,
301306
)
307+
if self.debug:
308+
debug("FOUND USER", db_user)
302309
if db_user is None:
303310
raise RuntimeError(
304-
f"User with id {self.user_kwargs['id']} doesn't exist"
311+
f"[MockCurrentUser] User with {user_id=} doesn't exist"
305312
)
313+
for k, v in self.user_kwargs.items():
314+
if not getattr(db_user, k) == v:
315+
raise RuntimeError(
316+
f"[MockCurrentUser] User with {user_id=} has "
317+
f"{k}={v}."
318+
)
306319
self.user = db_user
307-
# Removing objects from test db session, so that we can operate
308-
# on them from other sessions
309-
db.expunge(self.user)
310320
else:
311-
if (
312-
self.user_kwargs is not None
313-
and "profile_id" not in self.user_kwargs.keys()
314-
):
321+
# (2) Create new user
322+
default_user_kwargs = dict(
323+
email=self.email,
324+
hashed_password="fake_hashed_password",
325+
project_dir=PROJECT_DIR_PLACEHOLDER,
326+
is_verified=True,
327+
)
328+
329+
# (2/a) Handle resource and profile
330+
if "profile_id" not in self.user_kwargs.keys():
315331
res = await db.execute(select(Profile))
316-
profile = res.scalars().one_or_none()
332+
profile = res.scalars().first()
317333
if profile is None:
318334
resource = Resource(
319-
name="local resource 1",
335+
name="Local resource",
320336
type=ResourceType.LOCAL,
321337
jobs_local_dir="/jobs_local_dir",
322338
tasks_local_dir="/tasks_local_dir",
@@ -332,8 +348,8 @@ async def __aenter__(self):
332348
await db.commit()
333349
await db.refresh(resource)
334350
db.expunge(resource)
335-
336351
profile = Profile(
352+
username="test01",
337353
resource_id=resource.id,
338354
name="local_resource_profile_objects",
339355
resource_type=ResourceType.LOCAL,
@@ -342,76 +358,85 @@ async def __aenter__(self):
342358
await db.commit()
343359
await db.refresh(profile)
344360
db.expunge(profile)
345-
profile_id = profile.id
361+
default_user_kwargs["profile_id"] = profile.id
346362

347363
# Create new user
348-
user_attributes = dict(
349-
email=self.email,
350-
hashed_password="fake_hashed_password",
351-
project_dir="/fake",
352-
profile_id=profile_id,
353-
)
354-
if self.user_kwargs is not None:
355-
user_attributes.update(self.user_kwargs)
356-
self.user = UserOAuth(**user_attributes)
357-
358-
try:
359-
db.add(self.user)
360-
await db.commit()
361-
except IntegrityError:
362-
# Safety net, in case of non-unique email addresses
363-
await db.rollback()
364-
self.user.email = _random_email()
365-
db.add(self.user)
366-
await db.commit()
364+
default_user_kwargs.update(self.user_kwargs)
365+
self.user = UserOAuth(**default_user_kwargs)
366+
367+
db.add(self.user)
368+
await db.commit()
367369
await db.refresh(self.user)
368370

371+
if self.debug:
372+
debug("CREATED USER", self.user)
373+
369374
db.add(
370375
LinkUserGroup(
371-
user_id=self.user.id, group_id=default_user_group.id
376+
user_id=self.user.id,
377+
group_id=default_user_group.id,
372378
)
373379
)
380+
await db.commit()
381+
if self.debug:
382+
debug(
383+
f"Created link between user_id={self.user.id} and "
384+
f"group_id={default_user_group.id}."
385+
)
374386

375-
# Removing objects from test db session, so that we can operate
376-
# on them from other sessions
377-
db.expunge(self.user)
387+
# Removing objects from test db session, so that we can operate
388+
# on them from other sessions
389+
db.expunge(self.user)
378390

379391
# Find out which dependencies should be overridden, and store their
380392
# pre-override value
381393
if self.user.is_active:
382-
self.previous_dependencies[
383-
current_user_act
384-
] = app.dependency_overrides.get(current_user_act, None)
385-
if self.user.is_active and self.user.is_superuser:
386-
self.previous_dependencies[
387-
current_superuser_act
388-
] = app.dependency_overrides.get(current_superuser_act, None)
394+
dep = current_user_act
395+
self.previous_deps[dep] = app.dependency_overrides.get(
396+
dep, None
397+
)
398+
if self.debug:
399+
debug(f"Override {current_user_act}.")
400+
389401
if self.user.is_active and self.user.is_verified:
390-
self.previous_dependencies[
391-
current_user_act_ver
392-
] = app.dependency_overrides.get(current_user_act_ver, None)
402+
dep = current_user_act_ver
403+
self.previous_deps[dep] = app.dependency_overrides.get(
404+
dep, None
405+
)
406+
if self.debug:
407+
debug(f"Override {current_user_act_ver}.")
408+
393409
if (
394410
self.user.is_active
395411
and self.user.is_verified
396412
and self.user.profile_id is not None
397413
):
398-
self.previous_dependencies[
399-
current_user_act_ver_prof
400-
] = app.dependency_overrides.get(
401-
current_user_act_ver_prof, None
414+
dep = current_user_act_ver_prof
415+
if self.debug:
416+
debug(f"Override {current_user_act_ver_prof}.")
417+
self.previous_deps[dep] = app.dependency_overrides.get(
418+
dep, None
419+
)
420+
421+
if self.user.is_active and self.user.is_superuser:
422+
dep = current_superuser_act
423+
if self.debug:
424+
debug(f"Override {current_superuser_act}.")
425+
self.previous_deps[dep] = app.dependency_overrides.get(
426+
dep, None
402427
)
403428

404429
# Override dependencies in the FastAPI app
405-
for dep in self.previous_dependencies.keys():
406-
app.dependency_overrides[dep] = lambda: self.user
430+
for _dep in self.previous_deps.keys():
431+
app.dependency_overrides[_dep] = lambda: self.user
407432

408433
return self.user
409434

410435
async def __aexit__(self, *args, **kwargs):
411436
# Reset overridden dependencies to the original ones
412-
for dep, previous_dep in self.previous_dependencies.items():
437+
for _dep, previous_dep in self.previous_deps.items():
413438
if previous_dep is not None:
414-
app.dependency_overrides[dep] = previous_dep
439+
app.dependency_overrides[_dep] = previous_dep
415440

416441
return _MockCurrentUser
417442

0 commit comments

Comments
 (0)