|
2 | 2 | from datetime import datetime |
3 | 3 |
|
4 | 4 | from pydantic import BaseModel |
5 | | -from sqlalchemy import select, update, delete, func, and_ |
| 5 | +from sqlalchemy import select, update, delete, func, and_, inspect |
| 6 | +from sqlalchemy.sql import Join |
6 | 7 | from sqlalchemy.ext.asyncio import AsyncSession |
7 | 8 | from sqlalchemy.engine.row import Row |
8 | 9 |
|
9 | | -from app.core.models import TimestampModel |
10 | 10 | from .helper import ( |
11 | 11 | _extract_matching_columns_from_schema, |
12 | | - _extract_matching_columns_from_kwargs |
| 12 | + _extract_matching_columns_from_kwargs, |
| 13 | + _auto_detect_join_condition, |
| 14 | + _add_column_with_prefix |
13 | 15 | ) |
14 | 16 |
|
15 | 17 | ModelType = TypeVar("ModelType") |
@@ -191,7 +193,224 @@ async def get_multi( |
191 | 193 | total_count = await self.count(db=db, **kwargs) |
192 | 194 |
|
193 | 195 | return {"data": data, "total_count": total_count} |
194 | | - |
| 196 | + |
| 197 | + async def get_joined( |
| 198 | + self, |
| 199 | + db: AsyncSession, |
| 200 | + join_model: Type[ModelType], |
| 201 | + join_prefix: str | None = None, |
| 202 | + join_on: Union[Join, None] = None, |
| 203 | + schema_to_select: Union[Type[BaseModel], List, None] = None, |
| 204 | + join_schema_to_select: Union[Type[BaseModel], List, None] = None, |
| 205 | + join_type: str = "left", |
| 206 | + **kwargs |
| 207 | + ) -> Dict | None: |
| 208 | + """ |
| 209 | + Fetches a single record with a join on another model. If 'join_on' is not provided, the method attempts |
| 210 | + to automatically detect the join condition using foreign key relationships. |
| 211 | +
|
| 212 | + Parameters |
| 213 | + ---------- |
| 214 | + db : AsyncSession |
| 215 | + The SQLAlchemy async session. |
| 216 | + join_model : Type[ModelType] |
| 217 | + The model to join with. |
| 218 | + join_prefix : Optional[str] |
| 219 | + Optional prefix to be added to all columns of the joined model. If None, no prefix is added. |
| 220 | + join_on : Join, optional |
| 221 | + SQLAlchemy Join object for specifying the ON clause of the join. If None, the join condition is |
| 222 | + auto-detected based on foreign keys. |
| 223 | + schema_to_select : Union[Type[BaseModel], List, None], optional |
| 224 | + Pydantic schema for selecting specific columns from the primary model. |
| 225 | + join_schema_to_select : Union[Type[BaseModel], List, None], optional |
| 226 | + Pydantic schema for selecting specific columns from the joined model. |
| 227 | + join_type : str, default "left" |
| 228 | + Specifies the type of join operation to perform. Can be "left" for a left outer join or "inner" for an inner join. |
| 229 | + kwargs : dict |
| 230 | + Filters to apply to the query. |
| 231 | +
|
| 232 | + Returns |
| 233 | + ------- |
| 234 | + Dict | None |
| 235 | + The fetched database row or None if not found. |
| 236 | +
|
| 237 | + Examples |
| 238 | + -------- |
| 239 | + Simple example: Joining User and Tier models without explicitly providing join_on |
| 240 | + ```python |
| 241 | + result = await crud_user.get_joined( |
| 242 | + db=session, |
| 243 | + join_model=Tier, |
| 244 | + schema_to_select=UserSchema, |
| 245 | + join_schema_to_select=TierSchema |
| 246 | + ) |
| 247 | + ``` |
| 248 | +
|
| 249 | + Complex example: Joining with a custom join condition, additional filter parameters, and a prefix |
| 250 | + ```python |
| 251 | + from sqlalchemy import and_ |
| 252 | + result = await crud_user.get_joined( |
| 253 | + db=session, |
| 254 | + join_model=Tier, |
| 255 | + join_prefix="tier_", |
| 256 | + join_on=and_(User.tier_id == Tier.id, User.is_superuser == True), |
| 257 | + schema_to_select=UserSchema, |
| 258 | + join_schema_to_select=TierSchema, |
| 259 | + username="john_doe" |
| 260 | + ) |
| 261 | + ``` |
| 262 | +
|
| 263 | + Return example: prefix added, no schema_to_select or join_schema_to_select |
| 264 | + ```python |
| 265 | + { |
| 266 | + "id": 1, |
| 267 | + "name": "John Doe", |
| 268 | + "username": "john_doe", |
| 269 | + |
| 270 | + "hashed_password": "hashed_password_example", |
| 271 | + "profile_image_url": "https://profileimageurl.com/default.jpg", |
| 272 | + "uuid": "123e4567-e89b-12d3-a456-426614174000", |
| 273 | + "created_at": "2023-01-01T12:00:00", |
| 274 | + "updated_at": "2023-01-02T12:00:00", |
| 275 | + "deleted_at": null, |
| 276 | + "is_deleted": false, |
| 277 | + "is_superuser": false, |
| 278 | + "tier_id": 2, |
| 279 | + "tier_name": "Premium", |
| 280 | + "tier_created_at": "2022-12-01T10:00:00", |
| 281 | + "tier_updated_at": "2023-01-01T11:00:00" |
| 282 | + } |
| 283 | + ``` |
| 284 | + """ |
| 285 | + if join_on is None: |
| 286 | + join_on = _auto_detect_join_condition(self._model, join_model) |
| 287 | + |
| 288 | + primary_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select) |
| 289 | + join_select = [] |
| 290 | + |
| 291 | + if join_schema_to_select: |
| 292 | + columns = _extract_matching_columns_from_schema(model=join_model, schema=join_schema_to_select) |
| 293 | + else: |
| 294 | + columns = inspect(join_model).c |
| 295 | + |
| 296 | + for column in columns: |
| 297 | + labeled_column = _add_column_with_prefix(column, join_prefix) |
| 298 | + if f"{join_prefix}{column.name}" not in [col.name for col in primary_select]: |
| 299 | + join_select.append(labeled_column) |
| 300 | + |
| 301 | + if join_type == "left": |
| 302 | + stmt = select(*primary_select, *join_select).outerjoin(join_model, join_on) |
| 303 | + elif join_type == "inner": |
| 304 | + stmt = select(*primary_select, *join_select).join(join_model, join_on) |
| 305 | + else: |
| 306 | + raise ValueError(f"Invalid join type: {join_type}. Only 'left' or 'inner' are valid.") |
| 307 | + |
| 308 | + for key, value in kwargs.items(): |
| 309 | + if hasattr(self._model, key): |
| 310 | + print(self._model) |
| 311 | + stmt = stmt.where(getattr(self._model, key) == value) |
| 312 | + |
| 313 | + db_row = await db.execute(stmt) |
| 314 | + result = db_row.first() |
| 315 | + if result: |
| 316 | + result = dict(result._mapping) |
| 317 | + |
| 318 | + return result |
| 319 | + |
| 320 | + async def get_multi_joined( |
| 321 | + self, |
| 322 | + db: AsyncSession, |
| 323 | + join_model: Type[ModelType], |
| 324 | + join_prefix: str | None = None, |
| 325 | + join_on: Union[Join, None] = None, |
| 326 | + schema_to_select: Union[Type[BaseModel], List[Type[BaseModel]], None] = None, |
| 327 | + join_schema_to_select: Union[Type[BaseModel], List[Type[BaseModel]], None] = None, |
| 328 | + join_type: str = "left", |
| 329 | + offset: int = 0, |
| 330 | + limit: int = 100, |
| 331 | + **kwargs: Any |
| 332 | + ) -> Dict[str, Any]: |
| 333 | + """ |
| 334 | + Fetch multiple records with a join on another model, allowing for pagination. |
| 335 | +
|
| 336 | + Parameters |
| 337 | + ---------- |
| 338 | + db : AsyncSession |
| 339 | + The SQLAlchemy async session. |
| 340 | + join_model : Type[ModelType] |
| 341 | + The model to join with. |
| 342 | + join_prefix : Optional[str] |
| 343 | + Optional prefix to be added to all columns of the joined model. If None, no prefix is added. |
| 344 | + join_on : Join, optional |
| 345 | + SQLAlchemy Join object for specifying the ON clause of the join. If None, the join condition is |
| 346 | + auto-detected based on foreign keys. |
| 347 | + schema_to_select : Union[Type[BaseModel], List[Type[BaseModel]], None], optional |
| 348 | + Pydantic schema for selecting specific columns from the primary model. |
| 349 | + join_schema_to_select : Union[Type[BaseModel], List[Type[BaseModel]], None], optional |
| 350 | + Pydantic schema for selecting specific columns from the joined model. |
| 351 | + join_type : str, default "left" |
| 352 | + Specifies the type of join operation to perform. Can be "left" for a left outer join or "inner" for an inner join. |
| 353 | + offset : int, default 0 |
| 354 | + The offset (number of records to skip) for pagination. |
| 355 | + limit : int, default 100 |
| 356 | + The limit (maximum number of records to return) for pagination. |
| 357 | + kwargs : dict |
| 358 | + Filters to apply to the primary query. |
| 359 | +
|
| 360 | + Returns |
| 361 | + ------- |
| 362 | + Dict[str, Any] |
| 363 | + A dictionary containing the fetched rows under 'data' key and total count under 'total_count'. |
| 364 | +
|
| 365 | + Examples |
| 366 | + -------- |
| 367 | + # Fetching multiple User records joined with Tier records, using left join |
| 368 | + users = await crud_user.get_multi_joined( |
| 369 | + db=session, |
| 370 | + join_model=Tier, |
| 371 | + join_prefix="tier_", |
| 372 | + schema_to_select=UserSchema, |
| 373 | + join_schema_to_select=TierSchema, |
| 374 | + offset=0, |
| 375 | + limit=10 |
| 376 | + ) |
| 377 | + """ |
| 378 | + if join_on is None: |
| 379 | + join_on = _auto_detect_join_condition(self._model, join_model) |
| 380 | + |
| 381 | + primary_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select) |
| 382 | + join_select = [] |
| 383 | + |
| 384 | + if join_schema_to_select: |
| 385 | + columns = _extract_matching_columns_from_schema(model=join_model, schema=join_schema_to_select) |
| 386 | + else: |
| 387 | + columns = inspect(join_model).c |
| 388 | + |
| 389 | + for column in columns: |
| 390 | + labeled_column = _add_column_with_prefix(column, join_prefix) |
| 391 | + if f"{join_prefix}{column.name}" not in [col.name for col in primary_select]: |
| 392 | + join_select.append(labeled_column) |
| 393 | + |
| 394 | + if join_type == "left": |
| 395 | + stmt = select(*primary_select, *join_select).outerjoin(join_model, join_on) |
| 396 | + elif join_type == "inner": |
| 397 | + stmt = select(*primary_select, *join_select).join(join_model, join_on) |
| 398 | + else: |
| 399 | + raise ValueError(f"Invalid join type: {join_type}. Only 'left' or 'inner' are valid.") |
| 400 | + |
| 401 | + for key, value in kwargs.items(): |
| 402 | + if hasattr(self._model, key): |
| 403 | + stmt = stmt.where(getattr(self._model, key) == value) |
| 404 | + |
| 405 | + stmt = stmt.offset(offset).limit(limit) |
| 406 | + |
| 407 | + db_rows = await db.execute(stmt) |
| 408 | + data = [dict(row._mapping) for row in db_rows] |
| 409 | + |
| 410 | + total_count = await self.count(db=db, **kwargs) |
| 411 | + |
| 412 | + return {"data": data, "total_count": total_count} |
| 413 | + |
195 | 414 | async def update( |
196 | 415 | self, |
197 | 416 | db: AsyncSession, |
|
0 commit comments