22from datetime import datetime
33
44from pydantic import BaseModel
5- from sqlalchemy import select , update
5+ from sqlalchemy import select , update , delete
66from sqlalchemy .ext .asyncio import AsyncSession
7+ from sqlalchemy .engine .row import Row
8+
9+ from .helper import _extract_matching_columns_from_schema , _extract_matching_columns_from_kwargs
710
811ModelType = TypeVar ("ModelType" )
912CreateSchemaType = TypeVar ("CreateSchemaType" , bound = BaseModel )
@@ -24,78 +27,68 @@ async def create(
2427 await db .commit ()
2528 return db_object
2629
27- async def get (self , db : AsyncSession , ** kwargs ) -> ModelType | None :
28- query = select (self ._model ).filter_by (** kwargs )
29- result = await db .execute (query )
30- return result .scalar_one_or_none ()
30+ async def get (self , db : AsyncSession , schema_to_select : Type [BaseModel ] | None = None , ** kwargs ) -> ModelType | None :
31+ to_select = _extract_matching_columns_from_schema (model = self ._model , schema = schema_to_select )
32+ stmt = select (* to_select ) \
33+ .filter_by (** kwargs )
34+
35+ result = await db .execute (stmt )
36+ return result .first ()
3137
32- async def get_multi (
33- self , db : AsyncSession , offset : int = 0 , limit : int = 100 , ** kwargs
34- ) -> List [ModelType ]:
35- query = select (self ._model ) \
38+ async def get_multi (self , db : AsyncSession , offset : int = 0 , limit : int = 100 , schema_to_select : Type [BaseModel ] | None = None , ** kwargs ) -> List [ModelType ]:
39+ to_select = _extract_matching_columns_from_schema (model = self ._model , schema = schema_to_select )
40+ stmt = select (* to_select ) \
3641 .filter_by (** kwargs ) \
3742 .offset (offset ) \
3843 .limit (limit )
3944
40- result = await db .execute (query )
41- return result .scalars ().all ()
42-
43- async def update (
44- self ,
45- db : AsyncSession ,
46- object : Union [UpdateSchemaType , Dict [str , Any ]],
47- db_object : ModelType | None = None ,
48- ** kwargs
49- ) -> ModelType | None :
50- db_object = db_object or await self .get (db = db , ** kwargs )
51- if db_object :
52- if isinstance (object , dict ):
53- update_data = object
54- else :
55- update_data = object .model_dump (exclude_unset = True )
56-
57- update_data .update ({"updated_at" : datetime .utcnow ()})
58- for field in object .__dict__ :
59- if field in update_data :
60- setattr (db_object , field , update_data [field ])
61- db .add (db_object )
62- await db .commit ()
63-
64- return db_object
45+ result = await db .execute (stmt )
46+ return result .all ()
47+
48+ async def exists (self , db : AsyncSession , ** kwargs ) -> bool :
49+ to_select = _extract_matching_columns_from_kwargs (model = self ._model , kwargs = kwargs )
50+ stmt = select (* to_select ) \
51+ .filter_by (** kwargs ) \
52+ .limit (1 )
53+ result = await db .execute (stmt )
54+
55+ return result .first () is not None
56+
57+ async def update (self , db : AsyncSession , object : Union [UpdateSchemaType , Dict [str , Any ]], ** kwargs ) -> ModelType | None :
58+ if isinstance (object , dict ):
59+ update_data = object
60+ else :
61+ update_data = object .model_dump (exclude_unset = True )
62+ update_data ["updated_at" ] = datetime .utcnow ()
6563
66- async def db_delete (
67- self ,
68- db : AsyncSession ,
69- db_object : ModelType | None = None ,
70- ** kwargs
71- ):
72- db_object = db_object or await self .get (db = db , ** kwargs )
73- await db .delete (db_object )
64+ stmt = update (self ._model ) \
65+ .filter_by (** kwargs ) \
66+ .values (update_data )
67+
68+ await db .execute (stmt )
7469 await db .commit ()
7570
76- return db_object
71+ async def db_delete (self , db : AsyncSession , ** kwargs ):
72+ stmt = delete (self ._model ).filter_by (** kwargs )
73+ await db .execute (stmt )
74+ await db .commit ()
7775
78- async def delete (
79- self ,
80- db : AsyncSession ,
81- db_object : ModelType | None = None ,
82- ** kwargs
83- ) -> ModelType | None :
84- db_object = db_object or await self .get (db = db , ** kwargs )
85- if db_object :
86- if "is_deleted" in db_object .__dict__ .keys ():
76+ async def delete (self , db : AsyncSession , db_row : Row | None = None , ** kwargs ) -> ModelType | None :
77+ db_row = db_row or await self .get (db = db , ** kwargs )
78+ if db_row :
79+ if "is_deleted" in db_row :
8780 object_dict = {
8881 "is_deleted" : True ,
8982 "deleted_at" : datetime .utcnow ()
9083 }
91- query = update (self ._model ) \
84+ stmt = update (self ._model ) \
9285 .filter_by (** kwargs ) \
9386 .values (object_dict )
9487
95- await db .execute (query )
88+ await db .execute (stmt )
9689 await db .commit ()
97- await db .refresh (db_object )
98- else :
99- db_object = await self .db_delete (db = db , db_object = db_object , ** kwargs )
10090
101- return db_object
91+ else :
92+ stmt = delete (self ._model ).filter_by (** kwargs )
93+ await db .execute (stmt )
94+ await db .commit ()
0 commit comments