Skip to content

Conversation

@NFUChen
Copy link

@NFUChen NFUChen commented Mar 31, 2024

Description:

This pull request introduces a CrudRepository interface inspired by Spring-Boot ORM conventions, offering a convenient abstraction layer for database operations. The primary goal is to streamline database interactions for developers by providing a standardized interface while still allowing flexibility for custom queries when needed.

Key Features:

  1. CrudRepository Interface:

    • The CrudRepository interface abstracts common CRUD operations such as save, delete, find_by_id, find_all, etc.
    • Developers can define their repository interfaces inheriting from CrudRepository and automatically gain implementations for these operations.
  2. Automatic Table Creation:

    • A create_all_tables method is provided within CrudRepository, facilitating automatic table creation based on SQLAlchemy metadata.
  3. Native Query Decorator:

    • The native_query decorator allows developers to execute custom SQL queries while still leveraging the framework's benefits.
    • Queries are executed through SQLAlchemy's connection, and the results are seamlessly mapped to SQLModel instances.
  4. Type Annotations and Generics:

    • Strong type annotations and generics are utilized throughout the codebase, enhancing readability and maintainability.
    • Pydantic models and SQLModel classes are supported, providing a flexible schema definition.
  5. Error Handling and Logging:

    • Error handling is implemented to catch exceptions during database operations, ensuring robustness.
    • Logging is utilized to record errors and provide visibility into potential issues.

Side Note:

Feedback on the design and implementation of these concepts is welcomed. While the current implementation offers a basic foundation, detailed implementations of these concepts are still subject to change based on community feedback and evolving requirements. As we strive to improve the framework, any input is valuable.

Current Implementation:

from typing import Any, Callable, Generic, Iterable, Optional, Type, TypeVar, get_args
from sqlalchemy.sql import text
from uuid import UUID

from pydantic import BaseModel
from sqlalchemy import Engine, create_engine
from sqlmodel import Session, SQLModel, select
 
T = TypeVar("T", SQLModel, BaseModel)
ID = TypeVar("ID", UUID,int)

import logging

logger = logging.getLogger(__name__)

class CrudRepository(Generic[ID,T]):
    def __init__(self, engine: Engine) -> None:
        self.engine = engine
        self.id_type ,self.model_class = self._get_model_id_type_with_class()
        
    @classmethod
    def create_all_tables(cls, url: str) -> Engine:
        engine = create_engine(url, echo=False)
        SQLModel.metadata.create_all(engine)
        return engine
        
    @classmethod
    def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]:
        return get_args(tp= cls.__mro__[0].__orig_bases__[0])

    def _commit_operation_in_session(self,session_operation: Callable[[Session], None], session: Session) -> bool:
        try:
            session_operation(session)
            session.commit()
        except Exception as error:
            logger.error(error)
            return False
        
        return True

    def _create_session(self) -> Session:
        return Session(self.engine, expire_on_commit= True)

    
    def find_by_id(self, id: ID) -> tuple[T, Session]:
        session = self._create_session()
        statement = select(self.model_class).where(self.model_class.id == id)  # type: ignore
        return (session.exec(statement).one(), session)

    def find_all_by_ids(self, ids: list[ID]) -> tuple[Iterable[T], Session]:
        session = self._create_session()
        statement = select(self.model_class).where(self.model_class.id in ids)  # type: ignore
        return (session.exec(statement).all(), session)
    

    def find_all(self) -> tuple[Iterable[T], Session]:
        session = self._create_session()
        statement = select(self.model_class) # type: ignore
        return (session.exec(statement).all(), session)

    def save(self, entity: T, session: Optional[Session] = None) -> T:
        self._commit_operation_in_session(lambda session: session.add(entity), session or self._create_session())
        return entity

    def save_all(self, entities: Iterable[T], session: Optional[Session] = None) -> bool:
        return self._commit_operation_in_session(
            lambda session: session.add_all(entities), session or self._create_session()
        )

    def delete(self, entity: T, session: Optional[Session] = None) -> bool:
        return self._commit_operation_in_session(
            lambda session: session.delete(entity), session or self._create_session()
        )
    
    def delete_all(self, entities: Iterable[T], session: Optional[Session] = None) -> bool:
        session = session or self._create_session()
        for entity in entities:
            session.delete(entity)
        session.commit()

        return True
    
    
def native_query(query: str, return_type: Type[T]) -> Any:
    def decorated(func: Callable[..., T]) -> Callable[..., T]:
        def wrapper(self: CrudRepository, **kwargs) -> T:
            with self.engine.connect() as connection:
                
                
                sql= text(query.format(**kwargs))
                query_result = connection.execute(sql)
                query_result_dicts = query_result.mappings().all()
                if return_type.__name__ == "Iterable":
                    cls_inside_inside_iterable = get_args(return_type)[0]
                    return [cls_inside_inside_iterable.model_validate(query_result) for query_result in query_result_dicts] # type: ignore
                return return_type.model_validate(list(query_result_dicts).pop())  # Create an instance of the specified model class
        return wrapper
    return decorated # type: ignore

Example Usage:

from sqlmodel import Field, SQLModel
from sqlmodel.repository.crud_repository import CrudRepository, native_query

class Hero(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
    secret_name: str
    age: int | None = None

class HeroRepository(CrudRepository[int, Hero]):
    @native_query("SELECT * FROM hero WHERE name = '{name}'", Hero)
    def get_hero_by_name(self, name: str) -> Hero:
        ...

# Usage
sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"
engine = CrudRepository.create_all_tables(sqlite_url)

hero_repo = HeroRepository(engine)

deadpond = Hero(name="Deadpond", secret_name="Dive Wilson")
hero_repo.save(deadpond)

print(hero_repo.find_all())
print(hero_repo.get_hero_by_name(name="Deadpond"))

@NFUChen NFUChen closed this Apr 7, 2024
@NFUChen NFUChen reopened this May 26, 2024
@alejsdev alejsdev added the feature New feature or request label Jul 12, 2024
@tiangolo
Copy link
Member

Thanks for the interest, I think this is currently not needed, so I'll pass on this for now. ☕

@tiangolo tiangolo closed this Apr 27, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

feature New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants