Skip to content

Commit 531f96b

Browse files
committed
get_joined and get_multi_joined methods created
1 parent d789c38 commit 531f96b

File tree

3 files changed

+289
-7
lines changed

3 files changed

+289
-7
lines changed

src/app/crud/crud_base.py

Lines changed: 223 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
from datetime import datetime
33

44
from pydantic import BaseModel
5-
from sqlalchemy import select, update, delete, func, and_
5+
from sqlalchemy import select, update, delete, func, and_, inspect
6+
from sqlalchemy.sql import Join
67
from sqlalchemy.ext.asyncio import AsyncSession
78
from sqlalchemy.engine.row import Row
89

9-
from app.core.models import TimestampModel
1010
from .helper import (
1111
_extract_matching_columns_from_schema,
12-
_extract_matching_columns_from_kwargs
12+
_extract_matching_columns_from_kwargs,
13+
_auto_detect_join_condition,
14+
_add_column_with_prefix
1315
)
1416

1517
ModelType = TypeVar("ModelType")
@@ -191,7 +193,224 @@ async def get_multi(
191193
total_count = await self.count(db=db, **kwargs)
192194

193195
return {"data": data, "total_count": total_count}
194-
196+
197+
async def get_joined(
198+
self,
199+
db: AsyncSession,
200+
join_model: Type[ModelType],
201+
join_prefix: str | None = None,
202+
join_on: Union[Join, None] = None,
203+
schema_to_select: Union[Type[BaseModel], List, None] = None,
204+
join_schema_to_select: Union[Type[BaseModel], List, None] = None,
205+
join_type: str = "left",
206+
**kwargs
207+
) -> Dict | None:
208+
"""
209+
Fetches a single record with a join on another model. If 'join_on' is not provided, the method attempts
210+
to automatically detect the join condition using foreign key relationships.
211+
212+
Parameters
213+
----------
214+
db : AsyncSession
215+
The SQLAlchemy async session.
216+
join_model : Type[ModelType]
217+
The model to join with.
218+
join_prefix : Optional[str]
219+
Optional prefix to be added to all columns of the joined model. If None, no prefix is added.
220+
join_on : Join, optional
221+
SQLAlchemy Join object for specifying the ON clause of the join. If None, the join condition is
222+
auto-detected based on foreign keys.
223+
schema_to_select : Union[Type[BaseModel], List, None], optional
224+
Pydantic schema for selecting specific columns from the primary model.
225+
join_schema_to_select : Union[Type[BaseModel], List, None], optional
226+
Pydantic schema for selecting specific columns from the joined model.
227+
join_type : str, default "left"
228+
Specifies the type of join operation to perform. Can be "left" for a left outer join or "inner" for an inner join.
229+
kwargs : dict
230+
Filters to apply to the query.
231+
232+
Returns
233+
-------
234+
Dict | None
235+
The fetched database row or None if not found.
236+
237+
Examples
238+
--------
239+
Simple example: Joining User and Tier models without explicitly providing join_on
240+
```python
241+
result = await crud_user.get_joined(
242+
db=session,
243+
join_model=Tier,
244+
schema_to_select=UserSchema,
245+
join_schema_to_select=TierSchema
246+
)
247+
```
248+
249+
Complex example: Joining with a custom join condition, additional filter parameters, and a prefix
250+
```python
251+
from sqlalchemy import and_
252+
result = await crud_user.get_joined(
253+
db=session,
254+
join_model=Tier,
255+
join_prefix="tier_",
256+
join_on=and_(User.tier_id == Tier.id, User.is_superuser == True),
257+
schema_to_select=UserSchema,
258+
join_schema_to_select=TierSchema,
259+
username="john_doe"
260+
)
261+
```
262+
263+
Return example: prefix added, no schema_to_select or join_schema_to_select
264+
```python
265+
{
266+
"id": 1,
267+
"name": "John Doe",
268+
"username": "john_doe",
269+
"email": "[email protected]",
270+
"hashed_password": "hashed_password_example",
271+
"profile_image_url": "https://profileimageurl.com/default.jpg",
272+
"uuid": "123e4567-e89b-12d3-a456-426614174000",
273+
"created_at": "2023-01-01T12:00:00",
274+
"updated_at": "2023-01-02T12:00:00",
275+
"deleted_at": null,
276+
"is_deleted": false,
277+
"is_superuser": false,
278+
"tier_id": 2,
279+
"tier_name": "Premium",
280+
"tier_created_at": "2022-12-01T10:00:00",
281+
"tier_updated_at": "2023-01-01T11:00:00"
282+
}
283+
```
284+
"""
285+
if join_on is None:
286+
join_on = _auto_detect_join_condition(self._model, join_model)
287+
288+
primary_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select)
289+
join_select = []
290+
291+
if join_schema_to_select:
292+
columns = _extract_matching_columns_from_schema(model=join_model, schema=join_schema_to_select)
293+
else:
294+
columns = inspect(join_model).c
295+
296+
for column in columns:
297+
labeled_column = _add_column_with_prefix(column, join_prefix)
298+
if f"{join_prefix}{column.name}" not in [col.name for col in primary_select]:
299+
join_select.append(labeled_column)
300+
301+
if join_type == "left":
302+
stmt = select(*primary_select, *join_select).outerjoin(join_model, join_on)
303+
elif join_type == "inner":
304+
stmt = select(*primary_select, *join_select).join(join_model, join_on)
305+
else:
306+
raise ValueError(f"Invalid join type: {join_type}. Only 'left' or 'inner' are valid.")
307+
308+
for key, value in kwargs.items():
309+
if hasattr(self._model, key):
310+
print(self._model)
311+
stmt = stmt.where(getattr(self._model, key) == value)
312+
313+
db_row = await db.execute(stmt)
314+
result = db_row.first()
315+
if result:
316+
result = dict(result._mapping)
317+
318+
return result
319+
320+
async def get_multi_joined(
321+
self,
322+
db: AsyncSession,
323+
join_model: Type[ModelType],
324+
join_prefix: str | None = None,
325+
join_on: Union[Join, None] = None,
326+
schema_to_select: Union[Type[BaseModel], List[Type[BaseModel]], None] = None,
327+
join_schema_to_select: Union[Type[BaseModel], List[Type[BaseModel]], None] = None,
328+
join_type: str = "left",
329+
offset: int = 0,
330+
limit: int = 100,
331+
**kwargs: Any
332+
) -> Dict[str, Any]:
333+
"""
334+
Fetch multiple records with a join on another model, allowing for pagination.
335+
336+
Parameters
337+
----------
338+
db : AsyncSession
339+
The SQLAlchemy async session.
340+
join_model : Type[ModelType]
341+
The model to join with.
342+
join_prefix : Optional[str]
343+
Optional prefix to be added to all columns of the joined model. If None, no prefix is added.
344+
join_on : Join, optional
345+
SQLAlchemy Join object for specifying the ON clause of the join. If None, the join condition is
346+
auto-detected based on foreign keys.
347+
schema_to_select : Union[Type[BaseModel], List[Type[BaseModel]], None], optional
348+
Pydantic schema for selecting specific columns from the primary model.
349+
join_schema_to_select : Union[Type[BaseModel], List[Type[BaseModel]], None], optional
350+
Pydantic schema for selecting specific columns from the joined model.
351+
join_type : str, default "left"
352+
Specifies the type of join operation to perform. Can be "left" for a left outer join or "inner" for an inner join.
353+
offset : int, default 0
354+
The offset (number of records to skip) for pagination.
355+
limit : int, default 100
356+
The limit (maximum number of records to return) for pagination.
357+
kwargs : dict
358+
Filters to apply to the primary query.
359+
360+
Returns
361+
-------
362+
Dict[str, Any]
363+
A dictionary containing the fetched rows under 'data' key and total count under 'total_count'.
364+
365+
Examples
366+
--------
367+
# Fetching multiple User records joined with Tier records, using left join
368+
users = await crud_user.get_multi_joined(
369+
db=session,
370+
join_model=Tier,
371+
join_prefix="tier_",
372+
schema_to_select=UserSchema,
373+
join_schema_to_select=TierSchema,
374+
offset=0,
375+
limit=10
376+
)
377+
"""
378+
if join_on is None:
379+
join_on = _auto_detect_join_condition(self._model, join_model)
380+
381+
primary_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select)
382+
join_select = []
383+
384+
if join_schema_to_select:
385+
columns = _extract_matching_columns_from_schema(model=join_model, schema=join_schema_to_select)
386+
else:
387+
columns = inspect(join_model).c
388+
389+
for column in columns:
390+
labeled_column = _add_column_with_prefix(column, join_prefix)
391+
if f"{join_prefix}{column.name}" not in [col.name for col in primary_select]:
392+
join_select.append(labeled_column)
393+
394+
if join_type == "left":
395+
stmt = select(*primary_select, *join_select).outerjoin(join_model, join_on)
396+
elif join_type == "inner":
397+
stmt = select(*primary_select, *join_select).join(join_model, join_on)
398+
else:
399+
raise ValueError(f"Invalid join type: {join_type}. Only 'left' or 'inner' are valid.")
400+
401+
for key, value in kwargs.items():
402+
if hasattr(self._model, key):
403+
stmt = stmt.where(getattr(self._model, key) == value)
404+
405+
stmt = stmt.offset(offset).limit(limit)
406+
407+
db_rows = await db.execute(stmt)
408+
data = [dict(row._mapping) for row in db_rows]
409+
410+
total_count = await self.count(db=db, **kwargs)
411+
412+
return {"data": data, "total_count": total_count}
413+
195414
async def update(
196415
self,
197416
db: AsyncSession,

src/app/crud/helper.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from typing import Any, List, Type, Union
1+
from typing import Any, List, Type, Union, Optional
2+
from sqlalchemy import inspect
3+
from sqlalchemy.orm import DeclarativeMeta
4+
from sqlalchemy.sql import ColumnElement
5+
from sqlalchemy.sql.schema import Column
6+
from sqlalchemy.sql.elements import Label
27

38
from pydantic import BaseModel
49

@@ -53,3 +58,61 @@ def _extract_matching_columns_from_column_names(model: Type[Base], column_names:
5358
column_list.append(getattr(model, column_name))
5459

5560
return column_list
61+
62+
def _auto_detect_join_condition(base_model: Type[DeclarativeMeta], join_model: Type[DeclarativeMeta]) -> Optional[ColumnElement]:
63+
"""
64+
Automatically detects the join condition for SQLAlchemy models based on foreign key relationships.
65+
This function scans the foreign keys in the base model and tries to match them with columns in the join model.
66+
67+
Parameters
68+
----------
69+
base_model : Type[DeclarativeMeta]
70+
The base SQLAlchemy model from which to join.
71+
join_model : Type[DeclarativeMeta]
72+
The SQLAlchemy model to join with the base model.
73+
74+
Returns
75+
-------
76+
Optional[ColumnElement]
77+
A SQLAlchemy ColumnElement representing the join condition, if successfully detected.
78+
79+
Raises
80+
------
81+
ValueError
82+
If the join condition cannot be automatically determined, a ValueError is raised.
83+
84+
Example
85+
-------
86+
# Assuming User has a foreign key reference to Tier:
87+
join_condition = auto_detect_join_condition(User, Tier)
88+
"""
89+
fk_columns = [col for col in inspect(base_model).c if col.foreign_keys]
90+
join_on = next(
91+
(base_model.__table__.c[col.name] == join_model.__table__.c[list(col.foreign_keys)[0].column.name]
92+
for col in fk_columns if list(col.foreign_keys)[0].column.table == join_model.__table__),
93+
None
94+
)
95+
96+
if join_on is None:
97+
raise ValueError("Could not automatically determine join condition. Please provide join_on.")
98+
99+
return join_on
100+
101+
def _add_column_with_prefix(column: Column, prefix: Optional[str]) -> Label:
102+
"""
103+
Creates a SQLAlchemy column label with an optional prefix.
104+
105+
Parameters
106+
----------
107+
column : Column
108+
The SQLAlchemy Column object to be labeled.
109+
prefix : Optional[str]
110+
An optional prefix to prepend to the column's name.
111+
112+
Returns
113+
-------
114+
Label
115+
A labeled SQLAlchemy Column object.
116+
"""
117+
column_label = f"{prefix}{column.name}" if prefix else column.name
118+
return column.label(column_label)

src/app/models/user.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import uuid as uuid_pkg
33
from datetime import datetime
44

5-
from sqlalchemy import String, DateTime
5+
from sqlalchemy import String, DateTime, ForeignKey
66
from sqlalchemy.orm import Mapped, mapped_column
77

88
from app.core.database import Base
@@ -31,4 +31,4 @@ class User(Base):
3131
is_deleted: Mapped[bool] = mapped_column(default=False, index=True)
3232
is_superuser: Mapped[bool] = mapped_column(default=False)
3333

34-
tier_id: Mapped[int | None] = mapped_column(index=True, default=None, init=False)
34+
tier_id: Mapped[int | None] = mapped_column(ForeignKey('tier.id'), index=True, default=None, init=False)

0 commit comments

Comments
 (0)