From 9fe222dde3cc4cf90c6b0fa898b115294eba308b Mon Sep 17 00:00:00 2001 From: William Chen Date: Sun, 31 Mar 2024 23:44:48 +0800 Subject: [PATCH 1/2] add: Add CrudRepository interface with native_query decorator --- example.py | 26 +++++++ sqlmodel/repository/crud_repository.py | 102 +++++++++++++++++++++++++ 2 files changed, 128 insertions(+) create mode 100644 example.py create mode 100644 sqlmodel/repository/crud_repository.py diff --git a/example.py b/example.py new file mode 100644 index 0000000000..d1933c0415 --- /dev/null +++ b/example.py @@ -0,0 +1,26 @@ +from sqlmodel import Field, SQLModel +from sqlmodel.repository.crud_repository import CrudRepository, native_query + + +class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + secret_name: str + age: int | None = None + + +class HeroRepository(CrudRepository[int, Hero]): + @native_query("SELECT * FROM hero WHERE name = '{name}'", Hero) + def get_hero_by_name(self, name: str) -> Hero: + ... + +sqlite_file_name = "database.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" +engine = CrudRepository.create_all_tables(sqlite_url) + +hero_repo = HeroRepository(engine) + +deadpond = Hero(name="Deadpond", secret_name="Dive Wilson") +hero_repo.save(deadpond) +print(hero_repo.find_all()) +print(hero_repo.get_hero_by_name(name="Deadpond")) \ No newline at end of file diff --git a/sqlmodel/repository/crud_repository.py b/sqlmodel/repository/crud_repository.py new file mode 100644 index 0000000000..c3b496b308 --- /dev/null +++ b/sqlmodel/repository/crud_repository.py @@ -0,0 +1,102 @@ + + +from typing import Any, Callable, Generic, Iterable, Optional, Type, TypeVar, get_args +from sqlalchemy.sql import text +from uuid import UUID + +from pydantic import BaseModel +from sqlalchemy import Engine, create_engine +from sqlmodel import Session, SQLModel, select + +T = TypeVar("T", SQLModel, BaseModel) +ID = TypeVar("ID", UUID,int) + +import logging + +logger = logging.getLogger(__name__) + +class CrudRepository(Generic[ID,T]): + def __init__(self, engine: Engine) -> None: + self.engine = engine + self.id_type ,self.model_class = self._get_model_id_type_with_class() + + @classmethod + def create_all_tables(cls, url: str) -> Engine: + engine = create_engine(url, echo=False) + SQLModel.metadata.create_all(engine) + return engine + + @classmethod + def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]: + return get_args(tp= cls.__mro__[0].__orig_bases__[0]) + + def _commit_operation_in_session(self,session_operation: Callable[[Session], None], session: Session) -> bool: + try: + session_operation(session) + session.commit() + except Exception as error: + logger.error(error) + return False + + return True + + def _create_session(self) -> Session: + return Session(self.engine, expire_on_commit= True) + + + def find_by_id(self, id: ID) -> tuple[T, Session]: + session = self._create_session() + statement = select(self.model_class).where(self.model_class.id == id) # type: ignore + return (session.exec(statement).one(), session) + + def find_all_by_ids(self, ids: list[ID]) -> tuple[Iterable[T], Session]: + session = self._create_session() + statement = select(self.model_class).where(self.model_class.id in ids) # type: ignore + return (session.exec(statement).all(), session) + + + def find_all(self) -> tuple[Iterable[T], Session]: + session = self._create_session() + statement = select(self.model_class) # type: ignore + return (session.exec(statement).all(), session) + + def save(self, entity: T, session: Optional[Session] = None) -> T: + self._commit_operation_in_session(lambda session: session.add(entity), session or self._create_session()) + return entity + + def save_all(self, entities: Iterable[T], session: Optional[Session] = None) -> bool: + return self._commit_operation_in_session( + lambda session: session.add_all(entities), session or self._create_session() + ) + + def delete(self, entity: T, session: Optional[Session] = None) -> bool: + return self._commit_operation_in_session( + lambda session: session.delete(entity), session or self._create_session() + ) + + def delete_all(self, entities: Iterable[T], session: Optional[Session] = None) -> bool: + session = session or self._create_session() + for entity in entities: + session.delete(entity) + session.commit() + + return True + + +def native_query(query: str, return_type: Type[T]) -> Any: + def decorated(func: Callable[..., T]) -> Callable[..., T]: + def wrapper(self: CrudRepository, **kwargs) -> T: + with self.engine.connect() as connection: + + + sql= text(query.format(**kwargs)) + query_result = connection.execute(sql) + query_result_dicts = query_result.mappings().all() + if return_type.__name__ == "Iterable": + cls_inside_inside_iterable = get_args(return_type)[0] + return [cls_inside_inside_iterable.model_validate(query_result) for query_result in query_result_dicts] # type: ignore + return return_type.model_validate(list(query_result_dicts).pop()) # Create an instance of the specified model class + # return model_instance + return wrapper + return decorated # type: ignore + \ No newline at end of file From 54080eb4c475dbf8cca32f360d9c5112ae6b4f60 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 31 Mar 2024 16:07:46 +0000 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- example.py | 3 +- sqlmodel/repository/crud_repository.py | 67 +++++++++++++++----------- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/example.py b/example.py index d1933c0415..d1b6925b1c 100644 --- a/example.py +++ b/example.py @@ -14,6 +14,7 @@ class HeroRepository(CrudRepository[int, Hero]): def get_hero_by_name(self, name: str) -> Hero: ... + sqlite_file_name = "database.db" sqlite_url = f"sqlite:///{sqlite_file_name}" engine = CrudRepository.create_all_tables(sqlite_url) @@ -23,4 +24,4 @@ def get_hero_by_name(self, name: str) -> Hero: deadpond = Hero(name="Deadpond", secret_name="Dive Wilson") hero_repo.save(deadpond) print(hero_repo.find_all()) -print(hero_repo.get_hero_by_name(name="Deadpond")) \ No newline at end of file +print(hero_repo.get_hero_by_name(name="Deadpond")) diff --git a/sqlmodel/repository/crud_repository.py b/sqlmodel/repository/crud_repository.py index c3b496b308..bd3c014946 100644 --- a/sqlmodel/repository/crud_repository.py +++ b/sqlmodel/repository/crud_repository.py @@ -1,49 +1,49 @@ - - from typing import Any, Callable, Generic, Iterable, Optional, Type, TypeVar, get_args -from sqlalchemy.sql import text from uuid import UUID from pydantic import BaseModel from sqlalchemy import Engine, create_engine +from sqlalchemy.sql import text from sqlmodel import Session, SQLModel, select - + T = TypeVar("T", SQLModel, BaseModel) -ID = TypeVar("ID", UUID,int) +ID = TypeVar("ID", UUID, int) import logging logger = logging.getLogger(__name__) -class CrudRepository(Generic[ID,T]): + +class CrudRepository(Generic[ID, T]): def __init__(self, engine: Engine) -> None: self.engine = engine - self.id_type ,self.model_class = self._get_model_id_type_with_class() - + self.id_type, self.model_class = self._get_model_id_type_with_class() + @classmethod def create_all_tables(cls, url: str) -> Engine: engine = create_engine(url, echo=False) SQLModel.metadata.create_all(engine) return engine - + @classmethod def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]: - return get_args(tp= cls.__mro__[0].__orig_bases__[0]) + return get_args(tp=cls.__mro__[0].__orig_bases__[0]) - def _commit_operation_in_session(self,session_operation: Callable[[Session], None], session: Session) -> bool: + def _commit_operation_in_session( + self, session_operation: Callable[[Session], None], session: Session + ) -> bool: try: session_operation(session) session.commit() except Exception as error: logger.error(error) return False - + return True def _create_session(self) -> Session: - return Session(self.engine, expire_on_commit= True) + return Session(self.engine, expire_on_commit=True) - def find_by_id(self, id: ID) -> tuple[T, Session]: session = self._create_session() statement = select(self.model_class).where(self.model_class.id == id) # type: ignore @@ -53,18 +53,21 @@ def find_all_by_ids(self, ids: list[ID]) -> tuple[Iterable[T], Session]: session = self._create_session() statement = select(self.model_class).where(self.model_class.id in ids) # type: ignore return (session.exec(statement).all(), session) - def find_all(self) -> tuple[Iterable[T], Session]: session = self._create_session() - statement = select(self.model_class) # type: ignore + statement = select(self.model_class) # type: ignore return (session.exec(statement).all(), session) def save(self, entity: T, session: Optional[Session] = None) -> T: - self._commit_operation_in_session(lambda session: session.add(entity), session or self._create_session()) + self._commit_operation_in_session( + lambda session: session.add(entity), session or self._create_session() + ) return entity - def save_all(self, entities: Iterable[T], session: Optional[Session] = None) -> bool: + def save_all( + self, entities: Iterable[T], session: Optional[Session] = None + ) -> bool: return self._commit_operation_in_session( lambda session: session.add_all(entities), session or self._create_session() ) @@ -73,30 +76,36 @@ def delete(self, entity: T, session: Optional[Session] = None) -> bool: return self._commit_operation_in_session( lambda session: session.delete(entity), session or self._create_session() ) - - def delete_all(self, entities: Iterable[T], session: Optional[Session] = None) -> bool: + + def delete_all( + self, entities: Iterable[T], session: Optional[Session] = None + ) -> bool: session = session or self._create_session() for entity in entities: session.delete(entity) session.commit() return True - - + + def native_query(query: str, return_type: Type[T]) -> Any: def decorated(func: Callable[..., T]) -> Callable[..., T]: def wrapper(self: CrudRepository, **kwargs) -> T: with self.engine.connect() as connection: - - - sql= text(query.format(**kwargs)) + sql = text(query.format(**kwargs)) query_result = connection.execute(sql) query_result_dicts = query_result.mappings().all() if return_type.__name__ == "Iterable": cls_inside_inside_iterable = get_args(return_type)[0] - return [cls_inside_inside_iterable.model_validate(query_result) for query_result in query_result_dicts] # type: ignore - return return_type.model_validate(list(query_result_dicts).pop()) # Create an instance of the specified model class + return [ + cls_inside_inside_iterable.model_validate(query_result) + for query_result in query_result_dicts + ] # type: ignore + return return_type.model_validate( + list(query_result_dicts).pop() + ) # Create an instance of the specified model class # return model_instance + return wrapper - return decorated # type: ignore - \ No newline at end of file + + return decorated # type: ignore