Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from sqlmodel import Field, SQLModel
from sqlmodel.repository.crud_repository import CrudRepository, native_query


class Hero(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
name: str
secret_name: str
age: int | None = None


class HeroRepository(CrudRepository[int, Hero]):
@native_query("SELECT * FROM hero WHERE name = '{name}'", Hero)
def get_hero_by_name(self, name: str) -> Hero:
...


sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
engine = CrudRepository.create_all_tables(sqlite_url)

hero_repo = HeroRepository(engine)

deadpond = Hero(name="Deadpond", secret_name="Dive Wilson")
hero_repo.save(deadpond)
print(hero_repo.find_all())
print(hero_repo.get_hero_by_name(name="Deadpond"))
111 changes: 111 additions & 0 deletions sqlmodel/repository/crud_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from typing import Any, Callable, Generic, Iterable, Optional, Type, TypeVar, get_args
from uuid import UUID

from pydantic import BaseModel
from sqlalchemy import Engine, create_engine
from sqlalchemy.sql import text
from sqlmodel import Session, SQLModel, select

T = TypeVar("T", SQLModel, BaseModel)
ID = TypeVar("ID", UUID, int)

import logging

logger = logging.getLogger(__name__)


class CrudRepository(Generic[ID, T]):
def __init__(self, engine: Engine) -> None:
self.engine = engine
self.id_type, self.model_class = self._get_model_id_type_with_class()

@classmethod
def create_all_tables(cls, url: str) -> Engine:
engine = create_engine(url, echo=False)
SQLModel.metadata.create_all(engine)
return engine

@classmethod
def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]:
return get_args(tp=cls.__mro__[0].__orig_bases__[0])

def _commit_operation_in_session(
self, session_operation: Callable[[Session], None], session: Session
) -> bool:
try:
session_operation(session)
session.commit()
except Exception as error:
logger.error(error)
return False

return True

def _create_session(self) -> Session:
return Session(self.engine, expire_on_commit=True)

def find_by_id(self, id: ID) -> tuple[T, Session]:
session = self._create_session()
statement = select(self.model_class).where(self.model_class.id == id) # type: ignore
return (session.exec(statement).one(), session)

def find_all_by_ids(self, ids: list[ID]) -> tuple[Iterable[T], Session]:
session = self._create_session()
statement = select(self.model_class).where(self.model_class.id in ids) # type: ignore
return (session.exec(statement).all(), session)

def find_all(self) -> tuple[Iterable[T], Session]:
session = self._create_session()
statement = select(self.model_class) # type: ignore
return (session.exec(statement).all(), session)

def save(self, entity: T, session: Optional[Session] = None) -> T:
self._commit_operation_in_session(
lambda session: session.add(entity), session or self._create_session()
)
return entity

def save_all(
self, entities: Iterable[T], session: Optional[Session] = None
) -> bool:
return self._commit_operation_in_session(
lambda session: session.add_all(entities), session or self._create_session()
)

def delete(self, entity: T, session: Optional[Session] = None) -> bool:
return self._commit_operation_in_session(
lambda session: session.delete(entity), session or self._create_session()
)

def delete_all(
self, entities: Iterable[T], session: Optional[Session] = None
) -> bool:
session = session or self._create_session()
for entity in entities:
session.delete(entity)
session.commit()

return True


def native_query(query: str, return_type: Type[T]) -> Any:
def decorated(func: Callable[..., T]) -> Callable[..., T]:
def wrapper(self: CrudRepository, **kwargs) -> T:
with self.engine.connect() as connection:
sql = text(query.format(**kwargs))
query_result = connection.execute(sql)
query_result_dicts = query_result.mappings().all()
if return_type.__name__ == "Iterable":
cls_inside_inside_iterable = get_args(return_type)[0]
return [
cls_inside_inside_iterable.model_validate(query_result)
for query_result in query_result_dicts
] # type: ignore
return return_type.model_validate(
list(query_result_dicts).pop()
) # Create an instance of the specified model class
# return model_instance

return wrapper

return decorated # type: ignore