1+
2+
3+ from typing import Any , Callable , Generic , Iterable , Optional , Type , TypeVar , get_args
4+ from sqlalchemy .sql import text
5+ from uuid import UUID
6+
7+ from pydantic import BaseModel
8+ from sqlalchemy import Engine , create_engine
9+ from sqlmodel import Session , SQLModel , select
10+
11+ T = TypeVar ("T" , SQLModel , BaseModel )
12+ ID = TypeVar ("ID" , UUID ,int )
13+
14+ import logging
15+
16+ logger = logging .getLogger (__name__ )
17+
18+ class CrudRepository (Generic [ID ,T ]):
19+ def __init__ (self , engine : Engine ) -> None :
20+ self .engine = engine
21+ self .id_type ,self .model_class = self ._get_model_id_type_with_class ()
22+
23+ @classmethod
24+ def create_all_tables (cls , url : str ) -> Engine :
25+ engine = create_engine (url , echo = False )
26+ SQLModel .metadata .create_all (engine )
27+ return engine
28+
29+ @classmethod
30+ def _get_model_id_type_with_class (cls ) -> tuple [Type [ID ], Type [T ]]:
31+ return get_args (tp = cls .__mro__ [0 ].__orig_bases__ [0 ])
32+
33+ def _commit_operation_in_session (self ,session_operation : Callable [[Session ], None ], session : Session ) -> bool :
34+ try :
35+ session_operation (session )
36+ session .commit ()
37+ except Exception as error :
38+ logger .error (error )
39+ return False
40+
41+ return True
42+
43+ def _create_session (self ) -> Session :
44+ return Session (self .engine , expire_on_commit = True )
45+
46+
47+ def find_by_id (self , id : ID ) -> tuple [T , Session ]:
48+ session = self ._create_session ()
49+ statement = select (self .model_class ).where (self .model_class .id == id ) # type: ignore
50+ return (session .exec (statement ).one (), session )
51+
52+ def find_all_by_ids (self , ids : list [ID ]) -> tuple [Iterable [T ], Session ]:
53+ session = self ._create_session ()
54+ statement = select (self .model_class ).where (self .model_class .id in ids ) # type: ignore
55+ return (session .exec (statement ).all (), session )
56+
57+
58+ def find_all (self ) -> tuple [Iterable [T ], Session ]:
59+ session = self ._create_session ()
60+ statement = select (self .model_class ) # type: ignore
61+ return (session .exec (statement ).all (), session )
62+
63+ def save (self , entity : T , session : Optional [Session ] = None ) -> T :
64+ self ._commit_operation_in_session (lambda session : session .add (entity ), session or self ._create_session ())
65+ return entity
66+
67+ def save_all (self , entities : Iterable [T ], session : Optional [Session ] = None ) -> bool :
68+ return self ._commit_operation_in_session (
69+ lambda session : session .add_all (entities ), session or self ._create_session ()
70+ )
71+
72+ def delete (self , entity : T , session : Optional [Session ] = None ) -> bool :
73+ return self ._commit_operation_in_session (
74+ lambda session : session .delete (entity ), session or self ._create_session ()
75+ )
76+
77+ def delete_all (self , entities : Iterable [T ], session : Optional [Session ] = None ) -> bool :
78+ session = session or self ._create_session ()
79+ for entity in entities :
80+ session .delete (entity )
81+ session .commit ()
82+
83+ return True
84+
85+
86+ def native_query (query : str , return_type : Type [T ]) -> Any :
87+ def decorated (func : Callable [..., T ]) -> Callable [..., T ]:
88+ def wrapper (self : CrudRepository , ** kwargs ) -> T :
89+ with self .engine .connect () as connection :
90+
91+
92+ sql = text (query .format (** kwargs ))
93+ query_result = connection .execute (sql )
94+ query_result_dicts = query_result .mappings ().all ()
95+ if return_type .__name__ == "Iterable" :
96+ cls_inside_inside_iterable = get_args (return_type )[0 ]
97+ return [cls_inside_inside_iterable .model_validate (query_result ) for query_result in query_result_dicts ] # type: ignore
98+ return return_type .model_validate (list (query_result_dicts ).pop ()) # Create an instance of the specified model class
99+ # return model_instance
100+ return wrapper
101+ return decorated # type: ignore
102+
0 commit comments