|
6 | 6 | from typing import TYPE_CHECKING |
7 | 7 |
|
8 | 8 | from sqlalchemy import DateTime, func |
| 9 | +from sqlalchemy.ext.asyncio import AsyncConnection |
9 | 10 | from sqlalchemy.ext.compiler import compiles |
10 | 11 | from sqlalchemy.sql import expression |
11 | 12 |
|
@@ -120,23 +121,31 @@ class DBStateAssertation: |
120 | 121 | """Class to handler context where we should not raise any excepted error. |
121 | 122 |
|
122 | 123 | Example: |
123 | | - with DBStateAssertation([PilotNotFoundError, SecretNotFoundError]): |
| 124 | + with DBStateAssertation(self.conn, [PilotNotFoundError, SecretNotFoundError]): |
124 | 125 | # Some code that should not raise any error filled previously |
125 | 126 | # Else, raises an error |
126 | 127 |
|
| 128 | + Note: Will rollback changes. |
| 129 | +
|
127 | 130 | """ |
128 | 131 |
|
129 | | - def __init__(self, exceptions: list[type[Exception]]) -> None: |
| 132 | + def __init__( |
| 133 | + self, conn: AsyncConnection, exceptions: list[type[Exception]] |
| 134 | + ) -> None: |
130 | 135 | self.exceptions = exceptions |
| 136 | + self.conn = conn |
131 | 137 |
|
132 | | - def __enter__(self): |
| 138 | + async def __aenter__(self): |
133 | 139 | pass |
134 | 140 |
|
135 | | - def __exit__(self, exc_type, exc_value, exc_tb): |
| 141 | + async def __aexit__(self, exc_type, exc_value, exc_tb): |
136 | 142 | if exc_type is None: |
137 | 143 | # No exception occurred |
138 | 144 | return False |
139 | 145 |
|
| 146 | + # If we get here, an error occured, so we rollback changes |
| 147 | + await self.conn.rollback() |
| 148 | + |
140 | 149 | # Check if the exception is among the expected ones |
141 | 150 | if any(issubclass(exc_type, exc) for exc in self.exceptions): |
142 | 151 | logger.error( |
|
0 commit comments