Skip to content

Commit 2701650

Browse files
Notification topics refactorization using a db table (#680)
Change completely the structure of topics in Hyperion. Instead of using an Enum difficult to edit, we have a table in db with all existing topics. This allows to ensure that users can only subscribe to *real* topics. Topics can be created either in the module definition (ex: `cinema`), or using a utils from an endpoint (creation of a new topic each time we create an advertiser in `announce`). Topics can now be restricted to a specific group. By default, all users are subscribed to each topic, they can unsubscribe manually if they want. A change in the frontend is required: - using `/notification/topics` the client will get all available topics - it could be nice to group topics by `module_root` This PR will reset existing subscriptions. All users will be subscribed to all topics, except for `advert` Fix #479 Requires #768 --------- Co-authored-by: Foucauld Bellanger <[email protected]> Co-authored-by: armanddidierjean <[email protected]>
1 parent b43d58b commit 2701650

File tree

17 files changed

+817
-265
lines changed

17 files changed

+817
-265
lines changed

app/app.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import AsyncGenerator, Awaitable, Callable
66
from contextlib import asynccontextmanager
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, cast
8+
from typing import TYPE_CHECKING
99

1010
import alembic.command as alembic_command
1111
import alembic.config as alembic_config
@@ -28,6 +28,7 @@
2828
from app.core.google_api.google_api import GoogleAPI
2929
from app.core.groups import models_groups
3030
from app.core.groups.groups_type import GroupType
31+
from app.core.notification.cruds_notification import get_notification_topic
3132
from app.core.schools import models_schools
3233
from app.core.schools.schools_type import SchoolType
3334
from app.core.utils.config import Settings
@@ -36,17 +37,19 @@
3637
disconnect_state,
3738
get_app_state,
3839
get_db,
40+
get_notification_manager,
3941
get_redis_client,
4042
init_app_state,
4143
)
42-
from app.module import module_list
44+
from app.module import all_modules, module_list
4345
from app.types.exceptions import (
4446
ContentHTTPException,
4547
GoogleAPIInvalidCredentialsError,
4648
MultipleWorkersWithoutRedisInitializationError,
4749
)
4850
from app.types.sqlalchemy import Base
4951
from app.utils import initialization
52+
from app.utils.communication.notifications import NotificationManager
5053
from app.utils.redis import limiter
5154
from app.utils.state import LifespanState
5255

@@ -293,6 +296,32 @@ def initialize_module_visibility(
293296
)
294297

295298

299+
async def initialize_notification_topics(
300+
db: AsyncSession,
301+
hyperion_error_logger: logging.Logger,
302+
notification_manager: NotificationManager,
303+
) -> None:
304+
existing_topics = await get_notification_topic(db=db)
305+
existing_topics_id = [topic.id for topic in existing_topics]
306+
for module in all_modules:
307+
if module.registred_topics:
308+
for registred_topic in module.registred_topics:
309+
if registred_topic.id not in existing_topics_id:
310+
# We want to register this new topic
311+
hyperion_error_logger.info(
312+
f"Registering topic {registred_topic.name} ({registred_topic.id})",
313+
)
314+
await notification_manager.register_new_topic(
315+
topic_id=registred_topic.id,
316+
name=registred_topic.name,
317+
module_root=registred_topic.module_root,
318+
topic_identifier=registred_topic.topic_identifier,
319+
restrict_to_group_id=registred_topic.restrict_to_group_id,
320+
restrict_to_members=registred_topic.restrict_to_members,
321+
db=db,
322+
)
323+
324+
296325
def use_route_path_as_operation_ids(app: FastAPI) -> None:
297326
"""
298327
Simplify operation IDs so that generated API clients have simpler function names.
@@ -404,17 +433,16 @@ async def init_lifespan(
404433

405434
# We get `init_app_state` as a dependency, as tests
406435
# should override it to provide their own state
407-
state = await app.dependency_overrides.get(
436+
state: LifespanState = await app.dependency_overrides.get(
408437
init_app_state,
409438
init_app_state,
410439
)(
411440
app=app,
412441
settings=settings,
413442
hyperion_error_logger=hyperion_error_logger,
414443
)
415-
state = cast("LifespanState", state)
416444

417-
redis_client = app.dependency_overrides.get(
445+
redis_client: Redis | None = app.dependency_overrides.get(
418446
get_redis_client,
419447
get_redis_client,
420448
)(state=state)
@@ -452,10 +480,15 @@ async def init_lifespan(
452480
hyperion_error_logger=hyperion_error_logger,
453481
)
454482

455-
async for db in app.dependency_overrides.get(
483+
get_db_dependency: Callable[
484+
[LifespanState],
485+
AsyncGenerator[AsyncSession, None],
486+
] = app.dependency_overrides.get(
456487
get_db,
457488
get_db,
458-
)(state=state):
489+
)
490+
491+
async for db in get_db_dependency(state):
459492
await initialization.use_lock_for_workers(
460493
init_google_API,
461494
"init_google_API",
@@ -466,6 +499,22 @@ async def init_lifespan(
466499
settings=settings,
467500
)
468501

502+
async for db in get_db_dependency(state):
503+
notification_manager = app.dependency_overrides.get(
504+
get_notification_manager,
505+
get_notification_manager,
506+
)(state)
507+
await initialization.use_lock_for_workers(
508+
initialize_notification_topics,
509+
"initialize_notification_topics",
510+
redis_client,
511+
number_of_workers,
512+
hyperion_error_logger,
513+
db=db,
514+
hyperion_error_logger=hyperion_error_logger,
515+
notification_manager=notification_manager,
516+
)
517+
469518
return state
470519

471520

app/core/groups/endpoints_groups.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@
1212

1313
from app.core.groups import cruds_groups, models_groups, schemas_groups
1414
from app.core.groups.groups_type import GroupType
15+
from app.core.notification.utils_notification import get_topics_restricted_to_group_id
1516
from app.core.users import cruds_users
1617
from app.dependencies import (
1718
get_db,
19+
get_notification_manager,
1820
get_request_id,
1921
is_user_an_ecl_member,
2022
is_user_in,
2123
)
2224
from app.types.module import CoreModule
25+
from app.utils.communication.notifications import NotificationManager
2326

2427
router = APIRouter(tags=["Groups"])
2528

@@ -166,15 +169,12 @@ async def create_membership(
166169
f"Create_membership: Admin user {user.id} ({user.name}) added user {user_db.id} ({user_db.email}) to group {group_db.id} ({group_db.name}) ({request_id})",
167170
)
168171

169-
try:
170-
membership_db = models_groups.CoreMembership(
171-
user_id=membership.user_id,
172-
group_id=membership.group_id,
173-
description=membership.description,
174-
)
175-
return await cruds_groups.create_membership(db=db, membership=membership_db)
176-
except ValueError as error:
177-
raise HTTPException(status_code=422, detail=str(error))
172+
membership_db = models_groups.CoreMembership(
173+
user_id=membership.user_id,
174+
group_id=membership.group_id,
175+
description=membership.description,
176+
)
177+
return await cruds_groups.create_membership(db=db, membership=membership_db)
178178

179179

180180
@router.post(
@@ -232,6 +232,7 @@ async def delete_membership(
232232
db: AsyncSession = Depends(get_db),
233233
user=Depends(is_user_in(GroupType.admin)),
234234
request_id: str = Depends(get_request_id),
235+
notification_manager: NotificationManager = Depends(get_notification_manager),
235236
):
236237
"""
237238
Delete a membership using the user and group ids.
@@ -243,6 +244,19 @@ async def delete_membership(
243244
f"Create_membership: Admin user {user.id} ({user.name}) removed user {membership.user_id} from group {membership.group_id} ({request_id})",
244245
)
245246

247+
# To remove a user from a group, we should unsubscribe this user from all
248+
# topic that required to be a member of this group
249+
restricted_topics = await get_topics_restricted_to_group_id(
250+
group_id=membership.group_id,
251+
db=db,
252+
)
253+
for topic in restricted_topics:
254+
await notification_manager.unsubscribe_user_to_topic(
255+
topic_id=topic.id,
256+
user_id=membership.user_id,
257+
db=db,
258+
)
259+
246260
await cruds_groups.delete_membership_by_group_and_user_id(
247261
group_id=membership.group_id,
248262
user_id=membership.user_id,
@@ -259,6 +273,7 @@ async def delete_batch_membership(
259273
db: AsyncSession = Depends(get_db),
260274
user=Depends(is_user_in(GroupType.admin)),
261275
request_id: str = Depends(get_request_id),
276+
notification_manager: NotificationManager = Depends(get_notification_manager),
262277
):
263278
"""
264279
This endpoint removes all users from a given group.
@@ -273,6 +288,20 @@ async def delete_batch_membership(
273288
if group_db is None:
274289
raise HTTPException(status_code=400, detail="Invalid group_id")
275290

291+
# To remove a user from a group, we should unsubscribe this user from all
292+
# topic that required to be a member of this group
293+
restricted_topics = await get_topics_restricted_to_group_id(
294+
group_id=batch_membership.group_id,
295+
db=db,
296+
)
297+
for topic in restricted_topics:
298+
for user_from_group in group_db.members:
299+
await notification_manager.unsubscribe_user_to_topic(
300+
topic_id=topic.id,
301+
user_id=user_from_group.id,
302+
db=db,
303+
)
304+
276305
hyperion_security_logger.warning(
277306
f"Create_batch_membership: Admin user {user.id} ({user.name}) removed all users from group {group_db.id} ({group_db.name}) in batch ({request_id})",
278307
)

app/core/notification/cruds_notification.py

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,66 @@
11
from collections.abc import Sequence
22
from datetime import date
3+
from uuid import UUID
34

45
from sqlalchemy import delete, select, update
56
from sqlalchemy.ext.asyncio import AsyncSession
67

78
from app.core.notification import models_notification
8-
from app.core.notification.notification_types import CustomTopic, Topic
9+
10+
11+
async def get_notification_topic(
12+
db: AsyncSession,
13+
) -> Sequence[models_notification.NotificationTopic]:
14+
result = await db.execute(select(models_notification.NotificationTopic))
15+
return result.scalars().all()
16+
17+
18+
async def get_notification_topic_by_id(
19+
topic_id: UUID,
20+
db: AsyncSession,
21+
) -> models_notification.NotificationTopic | None:
22+
result = await db.execute(
23+
select(models_notification.NotificationTopic).where(
24+
models_notification.NotificationTopic.id == topic_id,
25+
),
26+
)
27+
return result.scalars().first()
28+
29+
30+
async def get_topics_restricted_to_group_id(
31+
group_id: str,
32+
db: AsyncSession,
33+
) -> Sequence[models_notification.NotificationTopic]:
34+
result = await db.execute(
35+
select(models_notification.NotificationTopic).where(
36+
models_notification.NotificationTopic.restrict_to_group_id == group_id,
37+
),
38+
)
39+
return result.scalars().all()
40+
41+
42+
async def get_notification_topic_by_root_and_identifier(
43+
module_root: str,
44+
topic_identifier: str | None,
45+
db: AsyncSession,
46+
) -> models_notification.NotificationTopic | None:
47+
result = await db.execute(
48+
select(models_notification.NotificationTopic).where(
49+
models_notification.NotificationTopic.module_root == module_root,
50+
models_notification.NotificationTopic.topic_identifier == topic_identifier,
51+
),
52+
)
53+
return result.scalars().first()
54+
55+
56+
async def create_notification_topic(
57+
notification_topic: models_notification.NotificationTopic,
58+
db: AsyncSession,
59+
) -> None:
60+
"""Register a new topic in database and return it"""
61+
62+
db.add(notification_topic)
63+
await db.flush()
964

1065

1166
async def get_firebase_devices_by_user_id(
@@ -111,29 +166,25 @@ async def create_topic_membership(
111166

112167
async def delete_topic_membership(
113168
user_id: str,
114-
custom_topic: CustomTopic,
169+
topic_id: UUID,
115170
db: AsyncSession,
116171
):
117172
await db.execute(
118173
delete(models_notification.TopicMembership).where(
119174
models_notification.TopicMembership.user_id == user_id,
120-
models_notification.TopicMembership.topic == custom_topic.topic,
121-
models_notification.TopicMembership.topic_identifier
122-
== custom_topic.topic_identifier,
175+
models_notification.TopicMembership.topic_id == topic_id,
123176
),
124177
)
125178
await db.flush()
126179

127180

128-
async def get_topic_memberships_by_topic(
129-
custom_topic: CustomTopic,
181+
async def get_topic_memberships_by_topic_id(
182+
topic_id: str,
130183
db: AsyncSession,
131184
) -> Sequence[models_notification.TopicMembership]:
132185
result = await db.execute(
133186
select(models_notification.TopicMembership).where(
134-
models_notification.TopicMembership.topic == custom_topic.topic,
135-
models_notification.TopicMembership.topic_identifier
136-
== custom_topic.topic_identifier,
187+
models_notification.TopicMembership.topic_id == topic_id,
137188
),
138189
)
139190
return result.scalars().all()
@@ -151,46 +202,41 @@ async def get_topic_memberships_by_user_id(
151202
return result.scalars().all()
152203

153204

154-
async def get_topic_memberships_with_identifiers_by_user_id_and_topic(
205+
async def get_topic_memberships_with_identifiers_by_user_id_and_topic_id(
155206
user_id: str,
156-
topic: Topic,
207+
topic_id: str,
157208
db: AsyncSession,
158209
) -> Sequence[models_notification.TopicMembership]:
159210
result = await db.execute(
160211
select(models_notification.TopicMembership).where(
161212
models_notification.TopicMembership.user_id == user_id,
162-
models_notification.TopicMembership.topic == topic,
163-
models_notification.TopicMembership.topic_identifier != "",
213+
models_notification.TopicMembership.topic_id == topic_id,
164214
),
165215
)
166216
return result.scalars().all()
167217

168218

169-
async def get_topic_membership_by_user_id_and_custom_topic(
219+
async def get_topic_membership_by_user_id_and_topic_id(
170220
user_id: str,
171-
custom_topic: CustomTopic,
221+
topic_id: UUID,
172222
db: AsyncSession,
173223
) -> models_notification.TopicMembership | None:
174224
result = await db.execute(
175225
select(models_notification.TopicMembership).where(
176226
models_notification.TopicMembership.user_id == user_id,
177-
models_notification.TopicMembership.topic == custom_topic.topic,
178-
models_notification.TopicMembership.topic_identifier
179-
== custom_topic.topic_identifier,
227+
models_notification.TopicMembership.topic_id == topic_id,
180228
),
181229
)
182230
return result.scalars().first()
183231

184232

185-
async def get_user_ids_by_topic(
186-
custom_topic: CustomTopic,
233+
async def get_user_ids_by_topic_id(
234+
topic_id: UUID,
187235
db: AsyncSession,
188236
) -> list[str]:
189237
result = await db.execute(
190238
select(models_notification.TopicMembership.user_id).where(
191-
models_notification.TopicMembership.topic == custom_topic.topic,
192-
models_notification.TopicMembership.topic_identifier
193-
== custom_topic.topic_identifier,
239+
models_notification.TopicMembership.topic_id == topic_id,
194240
),
195241
)
196242
return list(result.scalars().all())

0 commit comments

Comments
 (0)