Skip to content

Commit 06a8b30

Browse files
Merge pull request #36 from tahzeer/1.0
Refactor data loading into external and fallback phases
2 parents 31e5bd6 + 8c46b42 commit 06a8b30

File tree

6 files changed

+123
-281
lines changed

6 files changed

+123
-281
lines changed

iam-staff-portal-api/src/iam_staff_portal_api/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
StaffRole,
2626
StaffRolePermission,
2727
)
28-
from .data import run_data_loader
28+
from .data import DataLoader
2929

3030

3131
class Initializer(AuthInitializer):
@@ -49,6 +49,6 @@ async def migrate():
4949
await StaffApplicationPermission.create_migrate()
5050
await StaffRolePermission.create_migrate()
5151

52-
await run_data_loader()
52+
await DataLoader.run()
5353

5454
asyncio.run(migrate())

iam-staff-portal-api/src/iam_staff_portal_api/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ class Settings(BaseSettings):
1212
env_nested_delimiter="__",
1313
)
1414
auth_api_get_staff_portal_applications: ApiAuthSettings = ApiAuthSettings(enabled=True)
15-
cache_expire_seconds: int = 7*24*60*60 # 7 days in seconds
15+
cache_expire_seconds: int = 7*24*60*60 # 7 days
16+
data_dir: str = "/opt/iam-staff-portal-data"
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .data_loader import DataLoader, run_data_loader
1+
from .data_loader import DataLoader, DataLoaderBase
22

3-
__all__ = ["DataLoader", "run_data_loader"]
3+
__all__ = ["DataLoader", "DataLoaderBase"]
Lines changed: 117 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import json
2+
import logging
3+
from abc import ABC
24
from datetime import date, datetime
35
from pathlib import Path
46
from typing import Any
@@ -9,124 +11,145 @@
911
from iam_core.models import LoginProvider
1012
from openg2p_fastapi_common.context import dbengine
1113

14+
from ..config import Settings
1215
from ..models import (
1316
StaffApplicationPermission,
1417
StaffPortalApplication,
1518
StaffRole,
1619
StaffRolePermission,
1720
)
1821

22+
_logger = logging.getLogger("iam-staff-data-loader")
1923

20-
class DataLoader:
21-
def __init__(self):
22-
self._data_dir = Path(__file__).resolve().parent
2324

24-
def load_login_providers(self) -> list[dict[str, Any]]:
25-
return self._load_dataset("login_providers.json")
25+
class DataLoaderBase(ABC):
26+
data_models = (
27+
LoginProvider,
28+
StaffPortalApplication,
29+
StaffRole,
30+
StaffApplicationPermission,
31+
StaffRolePermission,
32+
)
2633

27-
def load_staff_portal_applications(self) -> list[dict[str, Any]]:
28-
return self._load_dataset("staff_portal_applications.json")
34+
def get_mounted_data_dir(self) -> Path:
    """Return the externally configured data directory as a ``Path``.

    The location is read from the ``data_dir`` setting on every call, so
    the value is not cached on the loader.
    """
    settings = Settings.get_config(strict=False)
    return Path(settings.data_dir)
2936

30-
def load_staff_roles(self) -> list[dict[str, Any]]:
31-
return self._load_dataset("staff_roles.json")
37+
def get_fallback_data_dir(self) -> Path:
    """Return the directory that contains this module.

    Datasets stored next to the loader module are used as the fallback
    source by ``load_fallback_data``.
    """
    module_file = Path(__file__).resolve()
    return module_file.parent
3239

33-
def load_staff_application_permissions(self) -> list[dict[str, Any]]:
34-
return self._load_dataset("staff_application_permissions.json")
40+
def get_dataset_path(self, model, data_dir: Path) -> Path:
    """Return the JSON dataset path for ``model`` inside ``data_dir``.

    The file name is the model's ``__tablename__`` with a ``.json`` suffix.
    """
    filename = f"{model.__tablename__}.json"
    return data_dir / filename
3542

36-
def load_staff_role_permissions(self) -> list[dict[str, Any]]:
37-
return self._load_dataset("staff_role_permissions.json")
38-
39-
def _load_dataset(
43+
def load_dataset(
    self,
    model,
    data_dir: Path,
) -> list[dict[str, Any]]:
    """Load and validate the JSON dataset for ``model`` from ``data_dir``.

    Args:
        model: Model class whose dataset file is resolved through
            ``get_dataset_path``.
        data_dir: Directory expected to contain the dataset file.

    Returns:
        The parsed list of row dictionaries, or an empty list when the
        dataset file does not exist.

    Raises:
        ValueError: If the file is not valid JSON, is not a JSON array, or
            contains elements that are not JSON objects.
    """
    dataset_path = self.get_dataset_path(model, data_dir)

    # Read first and handle absence afterwards: a separate exists() check
    # can race with the file being removed before it is opened.
    try:
        raw_value = dataset_path.read_text(encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        # A missing dataset is not an error; the model is simply not seeded
        # from this directory.
        return []

    try:
        payload = json.loads(raw_value)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in {dataset_path}: {exc.msg}") from exc

    if not isinstance(payload, list):
        raise ValueError(f"{dataset_path} must be a JSON array of objects")

    if not all(isinstance(row, dict) for row in payload):
        raise ValueError(f"{dataset_path} must contain only JSON objects")

    return payload
5966

67+
async def seed_models_from_dir(
    self,
    session: AsyncSession,
    data_dir: Path,
) -> None:
    """Seed every model in ``data_models`` from the datasets in ``data_dir``.

    When ``data_dir`` does not exist or is not a directory, the call logs
    the skip and returns without touching the session. Otherwise each
    model's dataset is loaded and handed to ``seed_if_empty`` in the order
    declared by ``data_models``.
    """
    directory_usable = data_dir.exists() and data_dir.is_dir()
    if not directory_usable:
        _logger.info("Skipping missing data directory: %s", data_dir)
        return

    _logger.info("Loading data from %s", data_dir)

    for table_model in self.data_models:
        dataset_rows = self.load_dataset(table_model, data_dir)
        await self.seed_if_empty(session, table_model, dataset_rows)
81+
82+
async def seed_if_empty(
    self,
    session: AsyncSession,
    model,
    rows: list[dict[str, Any]],
) -> None:
    """Insert ``rows`` into ``model``'s table only when the table is empty.

    The table's row count is queried first; a populated table is left
    untouched. An empty ``rows`` list is a no-op. Rows are passed through
    ``coerce_rows_for_model`` before insertion. The caller is responsible
    for committing the session.
    """
    count_query = select(func.count()).select_from(model)
    existing_rows = await session.scalar(count_query)
    # A missing (None) count is treated the same as zero.
    table_has_rows = (existing_rows or 0) > 0

    if table_has_rows or not rows:
        return

    _logger.info("Seeding %s with %s rows", model.__tablename__, len(rows))
    prepared_rows = self.coerce_rows_for_model(model, rows)
    await session.execute(insert(model), prepared_rows)
97+
98+
def coerce_rows_for_model(
    self,
    model,
    rows: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Return copies of ``rows`` with date-like strings converted for ``model``.

    Columns are classified from ``model.__table__.columns``:

    * ``DateTime`` columns named ``created_at`` / ``updated_at`` are removed
      from every row, so the value is not supplied by the dataset.
    * Other ``DateTime`` columns holding a string are parsed with
      ``datetime.fromisoformat``.
    * ``Date`` columns holding a string are parsed with
      ``date.fromisoformat``.

    The input rows are never mutated; each returned row is a shallow copy.

    Raises:
        ValueError: If a string value cannot be parsed as an ISO date or
            datetime. The message names the offending column and table.
    """
    audit_columns = {"created_at", "updated_at"}
    datetime_columns: set[str] = set()
    date_columns: set[str] = set()

    for column in model.__table__.columns:
        if isinstance(column.type, DateTime):
            datetime_columns.add(column.name)
        elif isinstance(column.type, Date):
            date_columns.add(column.name)

    # Split once up front: which columns to drop and which to parse does not
    # depend on the row, so it need not be re-evaluated for every row.
    dropped_columns = datetime_columns & audit_columns
    parsers = (
        (datetime_columns - audit_columns, datetime.fromisoformat),
        (date_columns, date.fromisoformat),
    )
    table_name = getattr(model, "__tablename__", "<unknown>")

    coerced_rows: list[dict[str, Any]] = []
    for row in rows:
        coerced = dict(row)

        for column_name in dropped_columns:
            coerced.pop(column_name, None)

        for column_names, parse in parsers:
            for column_name in column_names:
                value = coerced.get(column_name)
                if not isinstance(value, str):
                    continue
                try:
                    coerced[column_name] = parse(value)
                except ValueError as exc:
                    # Add context so a bad dataset value can be located.
                    raise ValueError(
                        f"Invalid ISO value {value!r} for column "
                        f"{column_name!r} of {table_name}"
                    ) from exc

        coerced_rows.append(coerced)

    return coerced_rows
133+
134+
def create_session_factory(self) -> async_sessionmaker[AsyncSession]:
    """Build an async session factory bound to the current ``dbengine``.

    Sessions are created with ``expire_on_commit=False``.
    """
    engine = dbengine.get()
    return async_sessionmaker(engine, expire_on_commit=False)
136+
137+
138+
class DataLoader(DataLoaderBase):
139+
@classmethod
async def run(cls) -> None:
    """Run both seeding phases in a single session and commit once.

    The mounted-directory phase (``load_data``) runs first, followed by the
    packaged fallback phase (``load_fallback_data``). Because
    ``seed_if_empty`` skips tables that already contain rows, the fallback
    phase only fills tables the first phase left empty.
    """
    instance = cls()
    open_session = instance.create_session_factory()

    _logger.info("Starting IAM staff data loader")
    async with open_session() as db_session:
        for phase in (instance.load_data, instance.load_fallback_data):
            await phase(db_session)
        await db_session.commit()
    _logger.info("Completed IAM staff data loader")
150+
151+
async def load_data(self, session: AsyncSession) -> None:
    """Seed models from the mounted (externally configured) data directory."""
    mounted_dir = self.get_mounted_data_dir()
    await self.seed_models_from_dir(session, mounted_dir)
60153

61-
async def run_data_loader() -> None:
62-
loader = DataLoader()
63-
64-
async_session = async_sessionmaker(dbengine.get(), expire_on_commit=False)
65-
async with async_session() as session:
66-
await _seed_if_empty(session, LoginProvider, loader.load_login_providers())
67-
await _seed_if_empty(
68-
session,
69-
StaffPortalApplication,
70-
loader.load_staff_portal_applications(),
71-
)
72-
await _seed_if_empty(session, StaffRole, loader.load_staff_roles())
73-
await _seed_if_empty(
74-
session,
75-
StaffApplicationPermission,
76-
loader.load_staff_application_permissions(),
77-
)
78-
await _seed_if_empty(
79-
session,
80-
StaffRolePermission,
81-
loader.load_staff_role_permissions(),
82-
)
83-
await session.commit()
84-
85-
86-
async def _seed_if_empty(
87-
session: AsyncSession,
88-
model,
89-
rows: list[dict],
90-
) -> None:
91-
row_count = await session.scalar(select(func.count()).select_from(model))
92-
if row_count and row_count > 0:
93-
return
94-
95-
if not rows:
96-
return
97-
98-
await session.execute(insert(model), _coerce_rows_for_model(model, rows))
99-
100-
101-
def _coerce_rows_for_model(model, rows: list[dict]) -> list[dict]:
102-
datetime_columns: set[str] = set()
103-
date_columns: set[str] = set()
104-
105-
for column in model.__table__.columns:
106-
if isinstance(column.type, DateTime):
107-
datetime_columns.add(column.name)
108-
elif isinstance(column.type, Date):
109-
date_columns.add(column.name)
110-
111-
coerced_rows: list[dict] = []
112-
for row in rows:
113-
coerced = dict(row)
114-
115-
for column_name in datetime_columns:
116-
# Let ORM/DB defaults populate timestamp audit fields.
117-
if column_name in {"created_at", "updated_at"}:
118-
coerced.pop(column_name, None)
119-
continue
120-
121-
value = coerced.get(column_name)
122-
if isinstance(value, str):
123-
coerced[column_name] = datetime.fromisoformat(value)
124-
125-
for column_name in date_columns:
126-
value = coerced.get(column_name)
127-
if isinstance(value, str):
128-
coerced[column_name] = date.fromisoformat(value)
129-
130-
coerced_rows.append(coerced)
131-
132-
return coerced_rows
154+
async def load_fallback_data(self, session: AsyncSession) -> None:
    """Seed models from the fallback directory that holds this module."""
    fallback_dir = self.get_fallback_data_dir()
    await self.seed_models_from_dir(session, fallback_dir)

iam-staff-portal-api/src/iam_staff_portal_api/data/login_providers.json

Lines changed: 0 additions & 60 deletions
This file was deleted.

0 commit comments

Comments
 (0)