Skip to content

Commit 6b19cb4

Browse files
committed
crud_base updated to allow querying only necessary columns, ORM ditched for sqlalchemy core for many of the crud functions
1 parent 9387ecb commit 6b19cb4

File tree

4 files changed

+122
-92
lines changed

4 files changed

+122
-92
lines changed

src/app/api/v1/posts.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ async def write_post(
3232

3333
post_internal_dict = post.model_dump()
3434
post_internal_dict["created_by_user_id"] = db_user.id
35+
3536
post_internal = PostCreateInternal(**post_internal_dict)
3637
return await crud_posts.create(db=db, object=post_internal)
3738

@@ -70,7 +71,7 @@ async def read_post(
7071
return db_post
7172

7273

73-
@router.patch("/{username}/post/{id}", response_model=PostRead)
74+
@router.patch("/{username}/post/{id}")
7475
@cache(
7576
"{username}_post_cache",
7677
resource_id_name="id",
@@ -95,7 +96,8 @@ async def patch_post(
9596
if db_post is None:
9697
raise HTTPException(status_code=404, detail="Post not found")
9798

98-
return await crud_posts.update(db=db, object=values, db_object=db_post, id=id)
99+
await crud_posts.update(db=db, object=values, id=id)
100+
return {"message": "Post updated"}
99101

100102

101103
@router.delete("/{username}/post/{id}")
@@ -122,13 +124,9 @@ async def erase_post(
122124
if db_post is None:
123125
raise HTTPException(status_code=404, detail="Post not found")
124126

125-
deleted_post = await crud_posts.delete(db=db, db_object=db_post, id=id)
126-
if deleted_post.is_deleted == True:
127-
message = {"message": "Post deleted"}
128-
else:
129-
message = {"message": "Something went wrong"}
130-
131-
return message
127+
await crud_posts.delete(db=db, db_row=db_post, id=id)
128+
129+
return {"message": "Post deleted"}
132130

133131

134132
@router.delete("/{username}/db_post/{id}", dependencies=[Depends(get_current_superuser)])
@@ -152,11 +150,4 @@ async def erase_db_post(
152150
raise HTTPException(status_code=404, detail="Post not found")
153151

154152
await crud_posts.db_delete(db=db, db_object=db_post, id=id)
155-
deleted_post = await crud_posts.get(db=db, id=id)
156-
157-
if deleted_post is None:
158-
message = {"message": "Post deleted"}
159-
else:
160-
message = {"message": "Something went wrong"}
161-
162-
return message
153+
return {"message": "Post deleted from the database"}

src/app/api/v1/users.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
router = fastapi.APIRouter(tags=["users"])
1616

1717
@router.post("/user", response_model=UserRead, status_code=201)
18-
async def write_user(request: Request, user: UserCreate, db: Annotated[AsyncSession, Depends(async_get_db)]):
19-
db_user = await crud_users.get(db=db, email=user.email)
20-
if db_user:
18+
async def write_user(
19+
request: Request,
20+
user: UserCreate,
21+
db: Annotated[AsyncSession, Depends(async_get_db)]
22+
):
23+
email_row = await crud_users.exists(db=db, email=user.email)
24+
if email_row:
2125
raise HTTPException(status_code=400, detail="Email is already registered")
2226

23-
db_user = await crud_users.get(db=db, username=user.username)
24-
if db_user:
27+
username_row = await crud_users.exists(db=db, username=user.username)
28+
if username_row:
2529
raise HTTPException(status_code=400, detail="Username not available")
2630

2731
user_internal_dict = user.model_dump()
@@ -56,7 +60,7 @@ async def read_user(request: Request, username: str, db: Annotated[AsyncSession,
5660
return db_user
5761

5862

59-
@router.patch("/user/{username}", response_model=UserRead)
63+
@router.patch("/user/{username}")
6064
async def patch_user(
6165
request: Request,
6266
values: UserUpdate,
@@ -72,17 +76,17 @@ async def patch_user(
7276
raise privileges_exception
7377

7478
if values.username != db_user.username:
75-
existing_username = await crud_users.get(db=db, username=values.username)
76-
if existing_username is not None:
79+
existing_username = await crud_users.exists(db=db, username=values.username)
80+
if existing_username:
7781
raise HTTPException(status_code=400, detail="Username not available")
7882

7983
if values.email != db_user.email:
80-
existing_email = await crud_users.get(db=db, email=values.email)
84+
existing_email = await crud_users.exists(db=db, email=values.email)
8185
if existing_email:
8286
raise HTTPException(status_code=400, detail="Email is already registered")
8387

84-
db_user = await crud_users.update(db=db, object=values, db_object=db_user)
85-
return db_user
88+
await crud_users.update(db=db, object=values, username=username)
89+
return {"message": "User updated"}
8690

8791

8892
@router.delete("/user/{username}")
@@ -99,8 +103,8 @@ async def erase_user(
99103
if db_user.username != current_user.username:
100104
raise privileges_exception
101105

102-
db_user = await crud_users.delete(db=db, db_object=db_user, username=username)
103-
return db_user
106+
await crud_users.delete(db=db, db_row=db_user, username=username)
107+
return {"message": "User deleted"}
104108

105109

106110
@router.delete("/db_user/{username}", dependencies=[Depends(get_current_superuser)])
@@ -109,9 +113,9 @@ async def erase_db_user(
109113
username: str,
110114
db: Annotated[AsyncSession, Depends(async_get_db)]
111115
):
112-
db_user = await crud_users.get(db=db, username=username)
113-
if db_user is None:
116+
db_user = await crud_users.exists(db=db, username=username)
117+
if not db_user:
114118
raise HTTPException(status_code=404, detail="User not found")
115119

116-
db_user = await crud_users.db_delete(db=db, db_object=db_user, username=username)
117-
return db_user
120+
db_user = await crud_users.db_delete(db=db, username=username)
121+
return {"message": "User deleted from the database"}

src/app/crud/crud_base.py

Lines changed: 51 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
from datetime import datetime
33

44
from pydantic import BaseModel
5-
from sqlalchemy import select, update
5+
from sqlalchemy import select, update, delete
66
from sqlalchemy.ext.asyncio import AsyncSession
7+
from sqlalchemy.engine.row import Row
8+
9+
from .helper import _extract_matching_columns_from_schema, _extract_matching_columns_from_kwargs
710

811
ModelType = TypeVar("ModelType")
912
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
@@ -24,78 +27,68 @@ async def create(
2427
await db.commit()
2528
return db_object
2629

27-
async def get(self, db: AsyncSession, **kwargs) -> ModelType | None:
28-
query = select(self._model).filter_by(**kwargs)
29-
result = await db.execute(query)
30-
return result.scalar_one_or_none()
30+
async def get(self, db: AsyncSession, schema_to_select: Type[BaseModel] | None = None, **kwargs) -> ModelType | None:
31+
to_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select)
32+
stmt = select(*to_select) \
33+
.filter_by(**kwargs)
34+
35+
result = await db.execute(stmt)
36+
return result.first()
3137

32-
async def get_multi(
33-
self, db: AsyncSession, offset: int = 0, limit: int = 100, **kwargs
34-
) -> List[ModelType]:
35-
query = select(self._model) \
38+
async def get_multi(self, db: AsyncSession, offset: int = 0, limit: int = 100, schema_to_select: Type[BaseModel] | None = None, **kwargs) -> List[ModelType]:
39+
to_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select)
40+
stmt = select(*to_select) \
3641
.filter_by(**kwargs) \
3742
.offset(offset) \
3843
.limit(limit)
3944

40-
result = await db.execute(query)
41-
return result.scalars().all()
42-
43-
async def update(
44-
self,
45-
db: AsyncSession,
46-
object: Union[UpdateSchemaType, Dict[str, Any]],
47-
db_object: ModelType | None = None,
48-
**kwargs
49-
) -> ModelType | None:
50-
db_object = db_object or await self.get(db=db, **kwargs)
51-
if db_object:
52-
if isinstance(object, dict):
53-
update_data = object
54-
else:
55-
update_data = object.model_dump(exclude_unset=True)
56-
57-
update_data.update({"updated_at": datetime.utcnow()})
58-
for field in object.__dict__:
59-
if field in update_data:
60-
setattr(db_object, field, update_data[field])
61-
db.add(db_object)
62-
await db.commit()
63-
64-
return db_object
45+
result = await db.execute(stmt)
46+
return result.all()
47+
48+
async def exists(self, db: AsyncSession, **kwargs) -> bool:
49+
to_select = _extract_matching_columns_from_kwargs(model=self._model, kwargs=kwargs)
50+
stmt = select(*to_select) \
51+
.filter_by(**kwargs) \
52+
.limit(1)
53+
result = await db.execute(stmt)
54+
55+
return result.first() is not None
56+
57+
async def update(self, db: AsyncSession, object: Union[UpdateSchemaType, Dict[str, Any]], **kwargs) -> ModelType | None:
58+
if isinstance(object, dict):
59+
update_data = object
60+
else:
61+
update_data = object.model_dump(exclude_unset=True)
62+
update_data["updated_at"] = datetime.utcnow()
6563

66-
async def db_delete(
67-
self,
68-
db: AsyncSession,
69-
db_object: ModelType | None = None,
70-
**kwargs
71-
):
72-
db_object = db_object or await self.get(db=db, **kwargs)
73-
await db.delete(db_object)
64+
stmt = update(self._model) \
65+
.filter_by(**kwargs) \
66+
.values(update_data)
67+
68+
await db.execute(stmt)
7469
await db.commit()
7570

76-
return db_object
71+
async def db_delete(self, db: AsyncSession, **kwargs):
72+
stmt = delete(self._model).filter_by(**kwargs)
73+
await db.execute(stmt)
74+
await db.commit()
7775

78-
async def delete(
79-
self,
80-
db: AsyncSession,
81-
db_object: ModelType | None = None,
82-
**kwargs
83-
) -> ModelType | None:
84-
db_object = db_object or await self.get(db=db, **kwargs)
85-
if db_object:
86-
if "is_deleted" in db_object.__dict__.keys():
76+
async def delete(self, db: AsyncSession, db_row: Row | None = None, **kwargs) -> ModelType | None:
77+
db_row = db_row or await self.get(db=db, **kwargs)
78+
if db_row:
79+
if "is_deleted" in db_row:
8780
object_dict = {
8881
"is_deleted": True,
8982
"deleted_at": datetime.utcnow()
9083
}
91-
query = update(self._model) \
84+
stmt = update(self._model) \
9285
.filter_by(**kwargs) \
9386
.values(object_dict)
9487

95-
await db.execute(query)
88+
await db.execute(stmt)
9689
await db.commit()
97-
await db.refresh(db_object)
98-
else:
99-
db_object = await self.db_delete(db=db, db_object=db_object, **kwargs)
10090

101-
return db_object
91+
else:
92+
stmt = delete(self._model).filter_by(**kwargs)
93+
await db.execute(stmt)
94+
await db.commit()

src/app/crud/helper.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from typing import Any, List, Type
2+
3+
from pydantic import BaseModel
4+
5+
from app.core.database import Base
6+
7+
def _extract_matching_columns_from_schema(model: Type[Base], schema: Type[BaseModel] | None) -> List[Any]:
8+
"""
9+
Retrieves a list of ORM column objects from a SQLAlchemy model that match the field names in a given Pydantic schema.
10+
11+
Parameters
12+
----------
13+
model: Type[Base]
14+
The SQLAlchemy ORM model containing columns to be matched with the schema fields.
15+
schema: Type[BaseModel]
16+
The Pydantic schema containing field names to be matched with the model's columns.
17+
18+
Returns
19+
-------
20+
List[Any]
21+
A list of ORM column objects from the model that correspond to the field names defined in the schema.
22+
"""
23+
column_list = list(model.__table__.columns)
24+
if schema is not None:
25+
schema_fields = schema.__fields__.keys()
26+
column_list = []
27+
for column_name in schema_fields:
28+
if hasattr(model, column_name):
29+
column_list.append(getattr(model, column_name))
30+
31+
return column_list
32+
33+
34+
def _extract_matching_columns_from_kwargs(model: Type[Base], kwargs: dict) -> List[Any]:
35+
if kwargs is not None:
36+
kwargs_fields = kwargs.keys()
37+
column_list = []
38+
for column_name in kwargs_fields:
39+
if hasattr(model, column_name):
40+
column_list.append(getattr(model, column_name))
41+
42+
return column_list

0 commit comments

Comments
 (0)