Skip to content

Commit 54080eb

Browse files
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
1 parent 9fe222d commit 54080eb

File tree

2 files changed

+40
-30
lines changed

2 files changed

+40
-30
lines changed

example.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class HeroRepository(CrudRepository[int, Hero]):
1414
def get_hero_by_name(self, name: str) -> Hero:
1515
...
1616

17+
1718
sqlite_file_name = "database.db"
1819
sqlite_url = f"sqlite:///{sqlite_file_name}"
1920
engine = CrudRepository.create_all_tables(sqlite_url)
@@ -23,4 +24,4 @@ def get_hero_by_name(self, name: str) -> Hero:
2324
deadpond = Hero(name="Deadpond", secret_name="Dive Wilson")
2425
hero_repo.save(deadpond)
2526
print(hero_repo.find_all())
26-
print(hero_repo.get_hero_by_name(name="Deadpond"))
27+
print(hero_repo.get_hero_by_name(name="Deadpond"))
Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,49 @@
1-
2-
31
from typing import Any, Callable, Generic, Iterable, Optional, Type, TypeVar, get_args
4-
from sqlalchemy.sql import text
52
from uuid import UUID
63

74
from pydantic import BaseModel
85
from sqlalchemy import Engine, create_engine
6+
from sqlalchemy.sql import text
97
from sqlmodel import Session, SQLModel, select
10-
8+
119
T = TypeVar("T", SQLModel, BaseModel)
12-
ID = TypeVar("ID", UUID,int)
10+
ID = TypeVar("ID", UUID, int)
1311

1412
import logging
1513

1614
logger = logging.getLogger(__name__)
1715

18-
class CrudRepository(Generic[ID,T]):
16+
17+
class CrudRepository(Generic[ID, T]):
1918
def __init__(self, engine: Engine) -> None:
2019
self.engine = engine
21-
self.id_type ,self.model_class = self._get_model_id_type_with_class()
22-
20+
self.id_type, self.model_class = self._get_model_id_type_with_class()
21+
2322
@classmethod
2423
def create_all_tables(cls, url: str) -> Engine:
2524
engine = create_engine(url, echo=False)
2625
SQLModel.metadata.create_all(engine)
2726
return engine
28-
27+
2928
@classmethod
3029
def _get_model_id_type_with_class(cls) -> tuple[Type[ID], Type[T]]:
31-
return get_args(tp= cls.__mro__[0].__orig_bases__[0])
30+
return get_args(tp=cls.__mro__[0].__orig_bases__[0])
3231

33-
def _commit_operation_in_session(self,session_operation: Callable[[Session], None], session: Session) -> bool:
32+
def _commit_operation_in_session(
33+
self, session_operation: Callable[[Session], None], session: Session
34+
) -> bool:
3435
try:
3536
session_operation(session)
3637
session.commit()
3738
except Exception as error:
3839
logger.error(error)
3940
return False
40-
41+
4142
return True
4243

4344
def _create_session(self) -> Session:
44-
return Session(self.engine, expire_on_commit= True)
45+
return Session(self.engine, expire_on_commit=True)
4546

46-
4747
def find_by_id(self, id: ID) -> tuple[T, Session]:
4848
session = self._create_session()
4949
statement = select(self.model_class).where(self.model_class.id == id) # type: ignore
@@ -53,18 +53,21 @@ def find_all_by_ids(self, ids: list[ID]) -> tuple[Iterable[T], Session]:
5353
session = self._create_session()
5454
statement = select(self.model_class).where(self.model_class.id in ids) # type: ignore
5555
return (session.exec(statement).all(), session)
56-
5756

5857
def find_all(self) -> tuple[Iterable[T], Session]:
5958
session = self._create_session()
60-
statement = select(self.model_class) # type: ignore
59+
statement = select(self.model_class) # type: ignore
6160
return (session.exec(statement).all(), session)
6261

6362
def save(self, entity: T, session: Optional[Session] = None) -> T:
64-
self._commit_operation_in_session(lambda session: session.add(entity), session or self._create_session())
63+
self._commit_operation_in_session(
64+
lambda session: session.add(entity), session or self._create_session()
65+
)
6566
return entity
6667

67-
def save_all(self, entities: Iterable[T], session: Optional[Session] = None) -> bool:
68+
def save_all(
69+
self, entities: Iterable[T], session: Optional[Session] = None
70+
) -> bool:
6871
return self._commit_operation_in_session(
6972
lambda session: session.add_all(entities), session or self._create_session()
7073
)
@@ -73,30 +76,36 @@ def delete(self, entity: T, session: Optional[Session] = None) -> bool:
7376
return self._commit_operation_in_session(
7477
lambda session: session.delete(entity), session or self._create_session()
7578
)
76-
77-
def delete_all(self, entities: Iterable[T], session: Optional[Session] = None) -> bool:
79+
80+
def delete_all(
81+
self, entities: Iterable[T], session: Optional[Session] = None
82+
) -> bool:
7883
session = session or self._create_session()
7984
for entity in entities:
8085
session.delete(entity)
8186
session.commit()
8287

8388
return True
84-
85-
89+
90+
8691
def native_query(query: str, return_type: Type[T]) -> Any:
8792
def decorated(func: Callable[..., T]) -> Callable[..., T]:
8893
def wrapper(self: CrudRepository, **kwargs) -> T:
8994
with self.engine.connect() as connection:
90-
91-
92-
sql= text(query.format(**kwargs))
95+
sql = text(query.format(**kwargs))
9396
query_result = connection.execute(sql)
9497
query_result_dicts = query_result.mappings().all()
9598
if return_type.__name__ == "Iterable":
9699
cls_inside_inside_iterable = get_args(return_type)[0]
97-
return [cls_inside_inside_iterable.model_validate(query_result) for query_result in query_result_dicts] # type: ignore
98-
return return_type.model_validate(list(query_result_dicts).pop()) # Create an instance of the specified model class
100+
return [
101+
cls_inside_inside_iterable.model_validate(query_result)
102+
for query_result in query_result_dicts
103+
] # type: ignore
104+
return return_type.model_validate(
105+
list(query_result_dicts).pop()
106+
) # Create an instance of the specified model class
99107
# return model_instance
108+
100109
return wrapper
101-
return decorated # type: ignore
102-
110+
111+
return decorated # type: ignore

0 commit comments

Comments
 (0)