Skip to content

Commit 98777ff

Browse files
feat: Adding automatic rollback in case of a change.
1 parent 329b4f0 commit 98777ff

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

diracx-db/src/diracx/db/sql/utils/functions.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import TYPE_CHECKING
77

88
from sqlalchemy import DateTime, func
9+
from sqlalchemy.ext.asyncio import AsyncConnection
910
from sqlalchemy.ext.compiler import compiles
1011
from sqlalchemy.sql import expression
1112

@@ -120,23 +121,31 @@ class DBStateAssertation:
120121
"""Class to handler context where we should not raise any excepted error.
121122
122123
Example:
123-
with DBStateAssertation([PilotNotFoundError, SecretNotFoundError]):
124+
with DBStateAssertation(self.conn, [PilotNotFoundError, SecretNotFoundError]):
124125
# Some code that should not raise any error filled previously
125126
# Else, raises an error
126127
128+
Note: Will rollback changes.
129+
127130
"""
128131

129-
def __init__(self, exceptions: list[type[Exception]]) -> None:
132+
def __init__(
133+
self, conn: AsyncConnection, exceptions: list[type[Exception]]
134+
) -> None:
130135
self.exceptions = exceptions
136+
self.conn = conn
131137

132-
def __enter__(self):
138+
async def __aenter__(self):
133139
pass
134140

135-
def __exit__(self, exc_type, exc_value, exc_tb):
141+
async def __aexit__(self, exc_type, exc_value, exc_tb):
136142
if exc_type is None:
137143
# No exception occurred
138144
return False
139145

146+
# If we get here, an error occured, so we rollback changes
147+
await self.conn.rollback()
148+
140149
# Check if the exception is among the expected ones
141150
if any(issubclass(exc_type, exc) for exc in self.exceptions):
142151
logger.error(

0 commit comments

Comments
 (0)