1-
2-
31from typing import Any , Callable , Generic , Iterable , Optional , Type , TypeVar , get_args
4- from sqlalchemy .sql import text
52from uuid import UUID
63
74from pydantic import BaseModel
85from sqlalchemy import Engine , create_engine
6+ from sqlalchemy .sql import text
97from sqlmodel import Session , SQLModel , select
10-
8+
119T = TypeVar ("T" , SQLModel , BaseModel )
12- ID = TypeVar ("ID" , UUID ,int )
10+ ID = TypeVar ("ID" , UUID , int )
1311
1412import logging
1513
1614logger = logging .getLogger (__name__ )
1715
18- class CrudRepository (Generic [ID ,T ]):
16+
17+ class CrudRepository (Generic [ID , T ]):
1918 def __init__ (self , engine : Engine ) -> None :
2019 self .engine = engine
21- self .id_type , self .model_class = self ._get_model_id_type_with_class ()
22-
20+ self .id_type , self .model_class = self ._get_model_id_type_with_class ()
21+
2322 @classmethod
2423 def create_all_tables (cls , url : str ) -> Engine :
2524 engine = create_engine (url , echo = False )
2625 SQLModel .metadata .create_all (engine )
2726 return engine
28-
27+
2928 @classmethod
3029 def _get_model_id_type_with_class (cls ) -> tuple [Type [ID ], Type [T ]]:
31- return get_args (tp = cls .__mro__ [0 ].__orig_bases__ [0 ])
30+ return get_args (tp = cls .__mro__ [0 ].__orig_bases__ [0 ])
3231
33- def _commit_operation_in_session (self ,session_operation : Callable [[Session ], None ], session : Session ) -> bool :
32+ def _commit_operation_in_session (
33+ self , session_operation : Callable [[Session ], None ], session : Session
34+ ) -> bool :
3435 try :
3536 session_operation (session )
3637 session .commit ()
3738 except Exception as error :
3839 logger .error (error )
3940 return False
40-
41+
4142 return True
4243
4344 def _create_session (self ) -> Session :
44- return Session (self .engine , expire_on_commit = True )
45+ return Session (self .engine , expire_on_commit = True )
4546
46-
4747 def find_by_id (self , id : ID ) -> tuple [T , Session ]:
4848 session = self ._create_session ()
4949 statement = select (self .model_class ).where (self .model_class .id == id ) # type: ignore
@@ -53,18 +53,21 @@ def find_all_by_ids(self, ids: list[ID]) -> tuple[Iterable[T], Session]:
5353 session = self ._create_session ()
5454 statement = select (self .model_class ).where (self .model_class .id in ids ) # type: ignore
5555 return (session .exec (statement ).all (), session )
56-
5756
5857 def find_all (self ) -> tuple [Iterable [T ], Session ]:
5958 session = self ._create_session ()
60- statement = select (self .model_class ) # type: ignore
59+ statement = select (self .model_class ) # type: ignore
6160 return (session .exec (statement ).all (), session )
6261
6362 def save (self , entity : T , session : Optional [Session ] = None ) -> T :
64- self ._commit_operation_in_session (lambda session : session .add (entity ), session or self ._create_session ())
63+ self ._commit_operation_in_session (
64+ lambda session : session .add (entity ), session or self ._create_session ()
65+ )
6566 return entity
6667
67- def save_all (self , entities : Iterable [T ], session : Optional [Session ] = None ) -> bool :
68+ def save_all (
69+ self , entities : Iterable [T ], session : Optional [Session ] = None
70+ ) -> bool :
6871 return self ._commit_operation_in_session (
6972 lambda session : session .add_all (entities ), session or self ._create_session ()
7073 )
@@ -73,30 +76,36 @@ def delete(self, entity: T, session: Optional[Session] = None) -> bool:
7376 return self ._commit_operation_in_session (
7477 lambda session : session .delete (entity ), session or self ._create_session ()
7578 )
76-
77- def delete_all (self , entities : Iterable [T ], session : Optional [Session ] = None ) -> bool :
79+
80+ def delete_all (
81+ self , entities : Iterable [T ], session : Optional [Session ] = None
82+ ) -> bool :
7883 session = session or self ._create_session ()
7984 for entity in entities :
8085 session .delete (entity )
8186 session .commit ()
8287
8388 return True
84-
85-
89+
90+
8691def native_query (query : str , return_type : Type [T ]) -> Any :
8792 def decorated (func : Callable [..., T ]) -> Callable [..., T ]:
8893 def wrapper (self : CrudRepository , ** kwargs ) -> T :
8994 with self .engine .connect () as connection :
90-
91-
92- sql = text (query .format (** kwargs ))
95+ sql = text (query .format (** kwargs ))
9396 query_result = connection .execute (sql )
9497 query_result_dicts = query_result .mappings ().all ()
9598 if return_type .__name__ == "Iterable" :
9699 cls_inside_inside_iterable = get_args (return_type )[0 ]
97- return [cls_inside_inside_iterable .model_validate (query_result ) for query_result in query_result_dicts ] # type: ignore
98- return return_type .model_validate (list (query_result_dicts ).pop ()) # Create an instance of the specified model class
100+ return [
101+ cls_inside_inside_iterable .model_validate (query_result )
102+ for query_result in query_result_dicts
103+ ] # type: ignore
104+ return return_type .model_validate (
105+ list (query_result_dicts ).pop ()
106+ ) # Create an instance of the specified model class
99107 # return model_instance
108+
100109 return wrapper
101- return decorated # type: ignore
102-
110+
111+ return decorated # type: ignore
0 commit comments