Skip to content

Commit 9fe222d

Browse files
William ChenWilliam Chen
authored andcommitted
add: Add CrudRepository interface with native_query decorator
1 parent 51df778 commit 9fe222d

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

example.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from sqlmodel import Field, SQLModel
2+
from sqlmodel.repository.crud_repository import CrudRepository, native_query
3+
4+
5+
class Hero(SQLModel, table=True):
6+
id: int | None = Field(default=None, primary_key=True)
7+
name: str
8+
secret_name: str
9+
age: int | None = None
10+
11+
12+
class HeroRepository(CrudRepository[int, Hero]):
13+
@native_query("SELECT * FROM hero WHERE name = '{name}'", Hero)
14+
def get_hero_by_name(self, name: str) -> Hero:
15+
...
16+
17+
sqlite_file_name = "database.db"
18+
sqlite_url = f"sqlite:///{sqlite_file_name}"
19+
engine = CrudRepository.create_all_tables(sqlite_url)
20+
21+
hero_repo = HeroRepository(engine)
22+
23+
deadpond = Hero(name="Deadpond", secret_name="Dive Wilson")
24+
hero_repo.save(deadpond)
25+
print(hero_repo.find_all())
26+
print(hero_repo.get_hero_by_name(name="Deadpond"))
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
2+
3+
from typing import Any, Callable, Generic, Iterable, Optional, Type, TypeVar, get_args
4+
from sqlalchemy.sql import text
5+
from uuid import UUID
6+
7+
from pydantic import BaseModel
8+
from sqlalchemy import Engine, create_engine
9+
from sqlmodel import Session, SQLModel, select
10+
11+
T = TypeVar("T", SQLModel, BaseModel)
12+
ID = TypeVar("ID", UUID,int)
13+
14+
import logging
15+
16+
logger = logging.getLogger(__name__)
17+
18+
class CrudRepository(Generic[ID,T]):
19+
def __init__(self, engine: Engine) -> None:
20+
self.engine = engine
21+
self.id_type ,self.model_class = self._get_model_id_type_with_class()
22+
23+
@classmethod
24+
def create_all_tables(cls, url: str) -> Engine:
25+
engine = create_engine(url, echo=False)
26+
SQLModel.metadata.create_all(engine)
27+
return engine
28+
29+
@classmethod
30+
def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]:
31+
return get_args(tp= cls.__mro__[0].__orig_bases__[0])
32+
33+
def _commit_operation_in_session(self,session_operation: Callable[[Session], None], session: Session) -> bool:
34+
try:
35+
session_operation(session)
36+
session.commit()
37+
except Exception as error:
38+
logger.error(error)
39+
return False
40+
41+
return True
42+
43+
def _create_session(self) -> Session:
44+
return Session(self.engine, expire_on_commit= True)
45+
46+
47+
def find_by_id(self, id: ID) -> tuple[T, Session]:
48+
session = self._create_session()
49+
statement = select(self.model_class).where(self.model_class.id == id) # type: ignore
50+
return (session.exec(statement).one(), session)
51+
52+
def find_all_by_ids(self, ids: list[ID]) -> tuple[Iterable[T], Session]:
53+
session = self._create_session()
54+
statement = select(self.model_class).where(self.model_class.id in ids) # type: ignore
55+
return (session.exec(statement).all(), session)
56+
57+
58+
def find_all(self) -> tuple[Iterable[T], Session]:
59+
session = self._create_session()
60+
statement = select(self.model_class) # type: ignore
61+
return (session.exec(statement).all(), session)
62+
63+
def save(self, entity: T, session: Optional[Session] = None) -> T:
64+
self._commit_operation_in_session(lambda session: session.add(entity), session or self._create_session())
65+
return entity
66+
67+
def save_all(self, entities: Iterable[T], session: Optional[Session] = None) -> bool:
68+
return self._commit_operation_in_session(
69+
lambda session: session.add_all(entities), session or self._create_session()
70+
)
71+
72+
def delete(self, entity: T, session: Optional[Session] = None) -> bool:
73+
return self._commit_operation_in_session(
74+
lambda session: session.delete(entity), session or self._create_session()
75+
)
76+
77+
def delete_all(self, entities: Iterable[T], session: Optional[Session] = None) -> bool:
78+
session = session or self._create_session()
79+
for entity in entities:
80+
session.delete(entity)
81+
session.commit()
82+
83+
return True
84+
85+
86+
def native_query(query: str, return_type: Type[T]) -> Any:
87+
def decorated(func: Callable[..., T]) -> Callable[..., T]:
88+
def wrapper(self: CrudRepository, **kwargs) -> T:
89+
with self.engine.connect() as connection:
90+
91+
92+
sql= text(query.format(**kwargs))
93+
query_result = connection.execute(sql)
94+
query_result_dicts = query_result.mappings().all()
95+
if return_type.__name__ == "Iterable":
96+
cls_inside_inside_iterable = get_args(return_type)[0]
97+
return [cls_inside_inside_iterable.model_validate(query_result) for query_result in query_result_dicts] # type: ignore
98+
return return_type.model_validate(list(query_result_dicts).pop()) # Create an instance of the specified model class
99+
# return model_instance
100+
return wrapper
101+
return decorated # type: ignore
102+

0 commit comments

Comments
 (0)