Skip to content

Commit 60a5fec

Browse files
Rotheemwarix8cotanoine
authored
Feat : Add base to allow factory setup (#375)
Co-authored-by: Warix <[email protected]> Co-authored-by: cotanoine <[email protected]>
1 parent 6f2caa4 commit 60a5fec

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+1429
-35
lines changed

app/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
api_router = APIRouter()
88

99

10-
for core_module in all_modules:
11-
api_router.include_router(core_module.router)
10+
for module in all_modules:
11+
api_router.include_router(module.router)

app/app.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import alembic.command as alembic_command
1111
import alembic.config as alembic_config
1212
import alembic.migration as alembic_migration
13+
import redis
1314
from calypsso import get_calypsso_app
1415
from fastapi import FastAPI, HTTPException, Request, Response, status
1516
from fastapi.encoders import jsonable_encoder
@@ -56,6 +57,8 @@
5657
if TYPE_CHECKING:
5758
import redis
5859

60+
from app.types.factory import Factory
61+
5962

6063
# NOTE: We can not get loggers at the top of this file like we do in other files
6164
# as the loggers are not yet initialized
@@ -226,6 +229,66 @@ def initialize_schools(
226229
)
227230

228231

232+
async def run_factories(
233+
db: AsyncSession,
234+
settings: Settings,
235+
hyperion_error_logger: logging.Logger,
236+
) -> None:
237+
"""Run the factories to create default data in the database"""
238+
if not settings.USE_FACTORIES:
239+
return
240+
241+
hyperion_error_logger.info("Startup: Factories enabled")
242+
# Importing the core_factory at the beginning of the factories.
243+
factories_list: list[Factory] = []
244+
for module in all_modules:
245+
if module.factory:
246+
factories_list.append(module.factory)
247+
hyperion_error_logger.info(
248+
f"Module {module.root} declares a factory {module.factory.__class__.__name__} with dependencies {module.factory.depends_on}",
249+
)
250+
else:
251+
hyperion_error_logger.warning(
252+
f"Module {module.root} does not declare a factory. It won't provide any base data.",
253+
)
254+
255+
# We have to run the factories in a specific order to make sure the dependencies are met
256+
# For that reason, we will run the first factory that has no dependencies, after that we remove it from the list of the dependencies from the other factories
257+
# And we loop until there are no more factories to run and we use a boolean to avoid infinite loops with circular dependencies
258+
no_factory_run_during_last_loop = False
259+
ran_factories: list[type[Factory]] = []
260+
while len(factories_list) > 0 and not no_factory_run_during_last_loop:
261+
no_factory_run_during_last_loop = True
262+
for factory in factories_list:
263+
if all(depend in ran_factories for depend in factory.depends_on):
264+
no_factory_run_during_last_loop = False
265+
# Check if the factory should be run
266+
if await factory.should_run(db):
267+
hyperion_error_logger.info(
268+
f"Startup: Running factory {factory.__class__.__name__}",
269+
)
270+
try:
271+
await factory.run(db, settings)
272+
except Exception as error:
273+
hyperion_error_logger.fatal(
274+
f"Startup: Could not run factories: {error}",
275+
)
276+
raise
277+
else:
278+
hyperion_error_logger.info(
279+
f"Startup: Factory {factory.__class__.__name__} is not necessary, skipping it",
280+
)
281+
ran_factories.append(factory.__class__)
282+
factories_list.remove(factory)
283+
break
284+
if no_factory_run_during_last_loop:
285+
hyperion_error_logger.error(
286+
"Factories are not correctly configured, some factories are not running.",
287+
)
288+
break
289+
hyperion_error_logger.info("Startup: Factories have been run")
290+
291+
229292
def initialize_module_visibility(
230293
sync_engine: Engine,
231294
hyperion_error_logger: logging.Logger,
@@ -487,7 +550,18 @@ async def init_lifespan(
487550
get_db,
488551
get_db,
489552
)
490-
553+
# We need to run the factories only once across all the workers
554+
async for db in get_db_dependency(state):
555+
await initialization.use_lock_for_workers(
556+
run_factories,
557+
"run_factories",
558+
redis_client,
559+
number_of_workers,
560+
hyperion_error_logger,
561+
db=db,
562+
settings=settings,
563+
hyperion_error_logger=hyperion_error_logger,
564+
)
491565
async for db in get_db_dependency(state):
492566
await initialization.use_lock_for_workers(
493567
init_google_API,

app/core/auth/endpoints_auth.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
root="auth",
5151
tag="Auth",
5252
router=router,
53+
factory=None,
5354
)
5455

5556
templates = Jinja2Templates(directory="assets/templates")

app/core/core_endpoints/endpoints_core.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import logging
21
from os import path
32
from pathlib import Path
43

@@ -23,13 +22,12 @@
2322
router = APIRouter(tags=["Core"])
2423

2524
core_module = CoreModule(
26-
root="",
25+
root="core",
2726
tag="Core",
2827
router=router,
28+
factory=None,
2929
)
3030

31-
hyperion_error_logger = logging.getLogger("hyperion.error")
32-
3331

3432
@router.get(
3533
"/information",

app/core/google_api/endpoints_google_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
root="google-api",
1818
tag="GoogleAPI",
1919
router=router,
20+
factory=None,
2021
)
2122

2223
hyperion_error_logger = logging.getLogger("hyperion.error")

app/core/groups/endpoints_groups.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sqlalchemy.ext.asyncio import AsyncSession
1212

1313
from app.core.groups import cruds_groups, models_groups, schemas_groups
14+
from app.core.groups.factory_groups import CoreGroupsFactory
1415
from app.core.groups.groups_type import GroupType
1516
from app.core.notification.utils_notification import get_topics_restricted_to_group_id
1617
from app.core.users import cruds_users
@@ -30,6 +31,7 @@
3031
root="groups",
3132
tag="Groups",
3233
router=router,
34+
factory=CoreGroupsFactory(),
3335
)
3436

3537
hyperion_security_logger = logging.getLogger("hyperion.security")

app/core/groups/factory_groups.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import random
2+
import uuid
3+
4+
from sqlalchemy.ext.asyncio import AsyncSession
5+
6+
from app.core.groups import cruds_groups
7+
from app.core.groups.groups_type import GroupType
8+
from app.core.groups.models_groups import CoreGroup, CoreMembership
9+
from app.core.users.factory_users import CoreUsersFactory
10+
from app.core.utils.config import Settings
11+
from app.types.factory import Factory
12+
13+
14+
class CoreGroupsFactory(Factory):
15+
groups_ids = [
16+
str(uuid.uuid4()),
17+
str(uuid.uuid4()),
18+
]
19+
20+
depends_on = [CoreUsersFactory]
21+
22+
@classmethod
23+
async def create_core_groups(cls, db: AsyncSession):
24+
groups = ["Oui", "Pixels"]
25+
descriptions = ["Groupe de test", "Groupe de test 2"]
26+
for i in range(len(groups)):
27+
await cruds_groups.create_group(
28+
db=db,
29+
group=CoreGroup(
30+
id=cls.groups_ids[i],
31+
name=groups[i],
32+
description=descriptions[i],
33+
),
34+
)
35+
36+
@classmethod
37+
async def create_core_memberships(cls, db: AsyncSession):
38+
for i in range(len(cls.groups_ids)):
39+
users = random.sample(CoreUsersFactory.other_users_id, 10)
40+
41+
for user_id in users:
42+
await cruds_groups.create_membership(
43+
db=db,
44+
membership=CoreMembership(
45+
group_id=cls.groups_ids[i],
46+
user_id=user_id,
47+
description=None,
48+
),
49+
)
50+
51+
@classmethod
52+
async def run(cls, db: AsyncSession, settings: Settings) -> None:
53+
await cls.create_core_groups(db=db)
54+
await cls.create_core_memberships(db=db)
55+
56+
@classmethod
57+
async def should_run(cls, db: AsyncSession):
58+
return len(await cruds_groups.get_groups(db=db)) == len(GroupType)

app/core/memberships/cruds_memberships.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ async def get_association_membership_by_id(
7979
)
8080

8181

82-
def create_association_membership(
82+
async def create_association_membership(
8383
db: AsyncSession,
8484
membership: schemas_memberships.MembershipSimple,
8585
):
@@ -347,7 +347,7 @@ async def get_user_membership_by_id(
347347
)
348348

349349

350-
def create_user_membership(
350+
async def create_user_membership(
351351
db: AsyncSession,
352352
user_membership: schemas_memberships.UserMembershipSimple,
353353
):

app/core/memberships/endpoints_memberships.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77

88
from app.core.groups import cruds_groups
99
from app.core.groups.groups_type import GroupType
10-
from app.core.memberships import cruds_memberships, schemas_memberships
10+
from app.core.memberships import (
11+
cruds_memberships,
12+
schemas_memberships,
13+
)
14+
from app.core.memberships.factory_memberships import CoreMembershipsFactory
1115
from app.core.memberships.utils_memberships import validate_user_new_membership
1216
from app.core.users import cruds_users, models_users, schemas_users
1317
from app.dependencies import (
@@ -25,6 +29,7 @@
2529
root="memberships",
2630
tag="Memberships",
2731
router=router,
32+
factory=CoreMembershipsFactory(),
2833
)
2934

3035

@@ -128,7 +133,7 @@ async def create_association_membership(
128133
id=uuid.uuid4(),
129134
)
130135

131-
cruds_memberships.create_association_membership(
136+
await cruds_memberships.create_association_membership(
132137
db=db,
133138
membership=db_association_membership,
134139
)
@@ -305,7 +310,10 @@ async def create_user_membership(
305310
)
306311
await validate_user_new_membership(db_user_membership, db)
307312

308-
cruds_memberships.create_user_membership(db=db, user_membership=db_user_membership)
313+
await cruds_memberships.create_user_membership(
314+
db=db,
315+
user_membership=db_user_membership,
316+
)
309317

310318
await db.flush()
311319

@@ -369,7 +377,7 @@ async def add_batch_membership(
369377
end_date=detail.end_date,
370378
)
371379
if len(stored_memberships) == 0:
372-
cruds_memberships.create_user_membership(
380+
await cruds_memberships.create_user_membership(
373381
db=db,
374382
user_membership=schemas_memberships.UserMembershipSimple(
375383
id=uuid.uuid4(),
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import datetime
2+
import random
3+
from uuid import uuid4
4+
5+
from sqlalchemy.ext.asyncio import AsyncSession
6+
7+
from app.core.groups.groups_type import GroupType
8+
from app.core.memberships import cruds_memberships
9+
from app.core.memberships.schemas_memberships import (
10+
MembershipSimple,
11+
UserMembershipSimple,
12+
)
13+
from app.core.users.factory_users import CoreUsersFactory
14+
from app.core.utils.config import Settings
15+
from app.types.factory import Factory
16+
17+
18+
class CoreMembershipsFactory(Factory):
19+
memberships_ids = [
20+
uuid4(),
21+
uuid4(),
22+
]
23+
memberships_names = [
24+
"AEECL",
25+
"USEECL",
26+
]
27+
memberships_manager_group_id = [
28+
GroupType.BDE.value,
29+
GroupType.BDS.value,
30+
]
31+
32+
depends_on = [CoreUsersFactory]
33+
34+
@classmethod
35+
async def run(cls, db: AsyncSession, settings: Settings) -> None:
36+
for i in range(len(cls.memberships_ids)):
37+
await cruds_memberships.create_association_membership(
38+
db,
39+
MembershipSimple(
40+
id=cls.memberships_ids[i],
41+
name=cls.memberships_names[i],
42+
manager_group_id=cls.memberships_manager_group_id[i],
43+
),
44+
)
45+
46+
members = random.sample(
47+
CoreUsersFactory.other_users_id,
48+
20,
49+
)
50+
for user_id in members:
51+
await cruds_memberships.create_user_membership(
52+
db=db,
53+
user_membership=UserMembershipSimple(
54+
id=uuid4(),
55+
user_id=user_id,
56+
association_membership_id=cls.memberships_ids[i],
57+
start_date=datetime.datetime(
58+
random.randint(2020, 2023), # noqa: S311
59+
random.randint(1, 12), # noqa: S311
60+
random.randint(1, 28), # noqa: S311
61+
tzinfo=datetime.UTC,
62+
),
63+
end_date=datetime.datetime(
64+
random.randint(2025, 2027), # noqa: S311
65+
random.randint(1, 12), # noqa: S311
66+
random.randint(1, 28), # noqa: S311
67+
tzinfo=datetime.UTC,
68+
),
69+
),
70+
)
71+
await db.commit()
72+
73+
@classmethod
74+
async def should_run(cls, db: AsyncSession):
75+
result = (
76+
len(
77+
await cruds_memberships.get_association_memberships(
78+
db=db,
79+
),
80+
)
81+
== 0
82+
)
83+
if not result:
84+
registered_memberships = (
85+
await cruds_memberships.get_association_memberships(
86+
db=db,
87+
)
88+
)
89+
cls.memberships_ids = [
90+
membership.id for membership in registered_memberships
91+
]

0 commit comments

Comments
 (0)