Skip to content

Commit 00de177

Browse files
committed
handled DB transaction globally
1 parent db018ba commit 00de177

File tree

4 files changed

+45
-44
lines changed

4 files changed

+45
-44
lines changed

fastapi_2fa/api/deps/db.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,11 @@ async def get_db() -> Generator:
77
try:
88
db = SessionLocal()
99
yield db
10+
await db.commit()
11+
print('committed in db...')
12+
except Exception as ex:
13+
print(f'rolling db for exception {ex} ...')
14+
await db.rollback()
1015
finally:
16+
print('closing db...')
1117
await db.close()

fastapi_2fa/crud/base_crud.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,31 @@ class CrudBase(
2222
UpdateSchemaType, # Pydantic Update Schema
2323
]
2424
):
25-
def __init__(self, model: Type[ModelType]):
25+
def __init__(self, model: Type[ModelType], transaction: bool = False):
2626
"""
2727
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
2828
29-
**Parameters**
30-
31-
* `model`: A SQLAlchemy model class
32-
* `schema`: A Pydantic model (schema) class
29+
Args:
30+
model (Type[ModelType]): A SQLAlchemy model class
31+
transaction (bool, optional): if True, changes won't be committed.
32+
Defaults to False.
3333
"""
34+
3435
self.model = model
36+
self.transaction = transaction
37+
38+
def __call__(self, transaction: bool) -> 'CrudBase':
39+
self.transaction = transaction
40+
return self
41+
42+
async def handle_commit(self, db: Session) -> bool:
43+
committed = False
44+
if not self.transaction:
45+
print(f'{self} is NOT under transaction, committing..')
46+
await db.commit()
47+
committed = True
48+
print(f'{self} is under transaction, NOT committing..')
49+
return committed
3550

3651
async def get(self, db: Session, id: Any) -> Optional[ModelType]:
3752
query = select(self.model).where(self.model.id == id)
@@ -49,12 +64,8 @@ async def create(self, db: Session, *, obj_in: CreateSchemaType):
4964
obj_in_data = jsonable_encoder(obj_in)
5065
db_obj: ModelType = self.model(**obj_in_data)
5166
db.add(db_obj)
52-
try:
53-
await db.commit()
54-
except Exception:
55-
await db.rollback()
56-
raise
57-
await db.refresh(db_obj)
67+
if await self.handle_commit(db):
68+
await db.refresh(db_obj)
5869
return db_obj
5970

6071
async def update(
@@ -76,21 +87,11 @@ async def update(
7687
.execution_options(synchronize_session="fetch")
7788
)
7889
await db.execute(query)
79-
try:
80-
# commit only if not under transaction
81-
if not db.in_nested_transaction():
82-
await db.commit()
83-
return db_obj
84-
except Exception as ex:
85-
await db.rollback()
86-
raise ex
90+
if await self.handle_commit(db):
91+
await db.refresh(db_obj)
92+
return db_obj
8793

8894
async def remove(self, db: Session, *, id: int) -> bool:
8995
query = delete(self.model).where(self.model.id == id)
9096
await db.execute(query)
91-
try:
92-
await db.commit()
93-
except Exception as ex:
94-
await db.rollback()
95-
raise ex
96-
return True
97+
return self.handle_commit(db)

fastapi_2fa/crud/device.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,26 @@
1010

1111
class DeviceCrud(CrudBase[Device, DeviceCreate, DeviceUpdate]):
1212

13-
@staticmethod
14-
async def create(db: Session, device: DeviceCreate, user: User) -> Device:
13+
async def create(
14+
self, db: Session, device: DeviceCreate, user: User
15+
) -> tuple[Device, SvgImage | None]:
1516
# create device
17+
encoded_key: str = create_encoded_two_factor_auth_key()
1618
db_device = Device(
17-
key=create_encoded_two_factor_auth_key(),
18-
user_id=user.id,
19+
key=encoded_key,
20+
user=user,
1921
device_type=device.device_type
2022
)
2123
# create backup tokens
2224
for token in get_fake_otp_tokens():
2325
token_db = BackupToken(
24-
device_id=db_device.id,
26+
device=db_device,
2527
token=token
2628
)
2729
db_device.backup_tokens.append(token_db)
2830
db.add(db_device)
29-
try:
30-
await db.commit()
31-
except Exception:
32-
await db.rollback()
33-
raise
34-
return db_device
31+
if await self.handle_commit(db):
32+
await db.refresh(db_device)
3533

3634

3735
device_crud = DeviceCrud(model=Device)

fastapi_2fa/crud/users.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,17 @@
1111

1212

1313
class UserCrud(CrudBase[User, UserCreate, UserUpdate]):
14-
@staticmethod
15-
async def create(db: Session, user: UserCreate) -> User:
14+
15+
async def create(self, db: Session, user: UserCreate) -> User:
1616
db_obj = User(
1717
email=user.email,
1818
hashed_password=get_password_hash(user.password),
1919
tfa_enabled=user.tfa_enabled,
2020
full_name=user.full_name,
2121
)
2222
db.add(db_obj)
23-
try:
24-
await db.commit()
25-
except Exception:
26-
await db.rollback()
27-
raise
28-
await db.refresh(db_obj)
23+
if await self.handle_commit(db):
24+
await db.refresh(db_obj)
2925
return db_obj
3026

3127
@staticmethod

0 commit comments

Comments
 (0)