|
1 | 1 | import json |
| 2 | +import logging |
| 3 | +from abc import ABC |
2 | 4 | from datetime import date, datetime |
3 | 5 | from pathlib import Path |
4 | 6 | from typing import Any |
|
9 | 11 | from iam_core.models import LoginProvider |
10 | 12 | from openg2p_fastapi_common.context import dbengine |
11 | 13 |
|
| 14 | +from ..config import Settings |
12 | 15 | from ..models import ( |
13 | 16 | StaffApplicationPermission, |
14 | 17 | StaffPortalApplication, |
15 | 18 | StaffRole, |
16 | 19 | StaffRolePermission, |
17 | 20 | ) |
18 | 21 |
|
| 22 | +_logger = logging.getLogger("iam-staff-data-loader") |
19 | 23 |
|
class DataLoaderBase(ABC):
    """Shared seeding logic for IAM staff reference data.

    Reads one JSON dataset per table (``<tablename>.json``) from a data
    directory and inserts the rows into each table that is still empty,
    coercing ISO-8601 strings into ``date``/``datetime`` values where
    the column type requires it.
    """

    # Seed order: each model is loaded in this sequence, so tables that
    # are referenced by later ones come first.
    data_models = (
        LoginProvider,
        StaffPortalApplication,
        StaffRole,
        StaffApplicationPermission,
        StaffRolePermission,
    )

    def get_mounted_data_dir(self) -> Path:
        """Return the operator-configured data directory from settings."""
        return Path(Settings.get_config(strict=False).data_dir)

    def get_fallback_data_dir(self) -> Path:
        """Return the packaged default data directory (next to this module)."""
        return Path(__file__).resolve().parent

    def get_dataset_path(self, model, data_dir: Path) -> Path:
        """Return the JSON dataset path for *model* inside *data_dir*."""
        return data_dir / f"{model.__tablename__}.json"

    def load_dataset(
        self,
        model,
        data_dir: Path,
    ) -> list[dict[str, Any]]:
        """Read and validate the JSON dataset for *model*.

        Returns an empty list when the dataset file does not exist.

        Raises:
            ValueError: if the file is not valid JSON, or is not a JSON
                array containing only objects.
        """
        dataset_path = self.get_dataset_path(model, data_dir)

        # EAFP: read directly and treat a missing file as "no data" —
        # avoids the exists()/read_text TOCTOU race.
        try:
            raw_value = dataset_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return []

        try:
            payload = json.loads(raw_value)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Invalid JSON in {dataset_path}: {exc.msg}") from exc

        if not isinstance(payload, list):
            raise ValueError(f"{dataset_path} must be a JSON array of objects")

        if any(not isinstance(row, dict) for row in payload):
            raise ValueError(f"{dataset_path} must contain only JSON objects")

        return payload

    async def seed_models_from_dir(
        self,
        session: AsyncSession,
        data_dir: Path,
    ) -> None:
        """Seed every model in ``data_models`` from *data_dir*.

        Silently skips the directory when it is missing or not a
        directory, so absent optional mounts are not an error.
        """
        if not data_dir.exists() or not data_dir.is_dir():
            _logger.info("Skipping missing data directory: %s", data_dir)
            return

        _logger.info("Loading data from %s", data_dir)

        for model in self.data_models:
            rows = self.load_dataset(model, data_dir)
            await self.seed_if_empty(session, model, rows)

    async def seed_if_empty(
        self,
        session: AsyncSession,
        model,
        rows: list[dict[str, Any]],
    ) -> None:
        """Bulk-insert *rows* into *model*'s table only if it is empty.

        Existing data always wins: a non-empty table is never touched.
        """
        row_count = await session.scalar(select(func.count()).select_from(model))
        # Truthiness covers both None and 0 — any positive count means
        # the table already has data and must not be re-seeded.
        if row_count:
            return

        if not rows:
            return

        _logger.info("Seeding %s with %s rows", model.__tablename__, len(rows))
        await session.execute(insert(model), self.coerce_rows_for_model(model, rows))

    def coerce_rows_for_model(
        self,
        model,
        rows: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Return copies of *rows* with ISO strings coerced to date/datetime.

        Audit timestamp fields (``created_at``/``updated_at``) are dropped
        so ORM/DB defaults populate them instead.
        """
        datetime_columns: set[str] = set()
        date_columns: set[str] = set()

        for column in model.__table__.columns:
            if isinstance(column.type, DateTime):
                datetime_columns.add(column.name)
            elif isinstance(column.type, Date):
                date_columns.add(column.name)

        coerced_rows: list[dict[str, Any]] = []
        for row in rows:
            coerced = dict(row)

            for column_name in datetime_columns:
                # Let ORM/DB defaults populate timestamp audit fields.
                if column_name in {"created_at", "updated_at"}:
                    coerced.pop(column_name, None)
                    continue

                value = coerced.get(column_name)
                if isinstance(value, str):
                    coerced[column_name] = datetime.fromisoformat(value)

            for column_name in date_columns:
                value = coerced.get(column_name)
                if isinstance(value, str):
                    coerced[column_name] = date.fromisoformat(value)

            coerced_rows.append(coerced)

        return coerced_rows

    def create_session_factory(self) -> async_sessionmaker[AsyncSession]:
        """Build an async session factory bound to the shared DB engine."""
        return async_sessionmaker(dbengine.get(), expire_on_commit=False)
| 137 | + |
class DataLoader(DataLoaderBase):
    """Concrete loader: seeds from the mounted data directory first,
    then fills any still-empty tables from the packaged fallback data."""

    @classmethod
    async def run(cls) -> None:
        """Entry point: open one session, seed everything, commit once."""
        instance = cls()
        make_session = instance.create_session_factory()

        _logger.info("Starting IAM staff data loader")
        async with make_session() as db_session:
            await instance.load_data(db_session)
            await instance.load_fallback_data(db_session)
            await db_session.commit()
        _logger.info("Completed IAM staff data loader")

    async def load_data(self, session: AsyncSession) -> None:
        """Seed from the operator-mounted data directory."""
        mounted_dir = self.get_mounted_data_dir()
        await self.seed_models_from_dir(session, mounted_dir)

    async def load_fallback_data(self, session: AsyncSession) -> None:
        """Seed from the datasets packaged alongside this module."""
        fallback_dir = self.get_fallback_data_dir()
        await self.seed_models_from_dir(session, fallback_dir)
0 commit comments