|
2 | 2 |
|
3 | 3 | import hashlib |
4 | 4 | from datetime import datetime, timedelta, timezone |
5 | | -from typing import TYPE_CHECKING |
| 5 | +from typing import TYPE_CHECKING, Any, Sequence, Type |
6 | 6 |
|
7 | | -from sqlalchemy import DateTime, func |
| 7 | +from sqlalchemy import DateTime, RowMapping, asc, desc, func, select |
| 8 | +from sqlalchemy.ext.asyncio import AsyncConnection |
8 | 9 | from sqlalchemy.ext.compiler import compiles |
9 | | -from sqlalchemy.sql import expression |
| 10 | +from sqlalchemy.sql import ColumnElement, expression |
| 11 | + |
| 12 | +from diracx.core.exceptions import DiracFormattedError, InvalidQueryError |
10 | 13 |
|
11 | 14 | if TYPE_CHECKING: |
12 | 15 | from sqlalchemy.types import TypeEngine |
13 | 16 |
|
14 | 17 |
|
| 18 | +def _get_columns(table, parameters): |
| 19 | + columns = [x for x in table.columns] |
| 20 | + if parameters: |
| 21 | + if unrecognised_parameters := set(parameters) - set(table.columns.keys()): |
| 22 | + raise InvalidQueryError( |
| 23 | + f"Unrecognised parameters requested {unrecognised_parameters}" |
| 24 | + ) |
| 25 | + columns = [c for c in columns if c.name in parameters] |
| 26 | + return columns |
| 27 | + |
| 28 | + |
15 | 29 | class utcnow(expression.FunctionElement): # noqa: N801 |
16 | 30 | type: TypeEngine = DateTime() |
17 | 31 | inherit_cache: bool = True |
@@ -140,3 +154,73 @@ def substract_date(**kwargs: float) -> datetime: |
140 | 154 |
|
141 | 155 | def hash(code: str): |
142 | 156 | return hashlib.sha256(code.encode()).hexdigest() |
| 157 | + |
| 158 | + |
| 159 | +def raw_hash(code: str): |
| 160 | + return hashlib.sha256(code.encode()).digest() |
| 161 | + |
| 162 | + |
| 163 | +async def fetch_records_bulk_or_raises( |
| 164 | + conn: AsyncConnection, |
| 165 | + model: Any, # Here, we currently must use `Any` because `declarative_base()` returns any |
| 166 | + missing_elements_error_cls: Type[DiracFormattedError], |
| 167 | + column_attribute_name: str, |
| 168 | + column_name: str, |
| 169 | + elements_to_fetch: list, |
| 170 | + order_by: tuple[str, str] | None = None, |
| 171 | + allow_more_than_one_result_per_input: bool = False, |
| 172 | + allow_no_result: bool = False, |
| 173 | +) -> Sequence[RowMapping]: |
| 174 | + """Fetches a list of elements in a table, returns a list of elements. |
| 175 | + All elements from the `element_to_fetch` **must** be present. |
| 176 | + Raises the specified error if at least one is missing. |
| 177 | +
|
| 178 | + Example: |
| 179 | + fetch_records_bulk_or_raises( |
| 180 | + self.conn, |
| 181 | + PilotAgents, |
| 182 | + PilotNotFound, |
| 183 | + "pilot_id", |
| 184 | + "PilotID", |
| 185 | + [1,2,3] |
| 186 | + ) |
| 187 | +
|
| 188 | + """ |
| 189 | + assert elements_to_fetch |
| 190 | + |
| 191 | + # Get the column that needs to be in elements_to_fetch |
| 192 | + column = getattr(model, column_attribute_name) |
| 193 | + |
| 194 | + # Create the request |
| 195 | + stmt = select(model).with_for_update().where(column.in_(elements_to_fetch)) |
| 196 | + |
| 197 | + if order_by: |
| 198 | + column_name_to_order_by, direction = order_by |
| 199 | + column_to_order_by = getattr(model, column_name_to_order_by) |
| 200 | + |
| 201 | + operator: ColumnElement = ( |
| 202 | + asc(column_to_order_by) if direction == "asc" else desc(column_to_order_by) |
| 203 | + ) |
| 204 | + |
| 205 | + stmt = stmt.order_by(operator) |
| 206 | + |
| 207 | + # Transform into dictionaries |
| 208 | + raw_results = await conn.execute(stmt) |
| 209 | + results = raw_results.mappings().all() |
| 210 | + |
| 211 | + # Detects duplicates |
| 212 | + if not allow_more_than_one_result_per_input: |
| 213 | + if len(results) > len(elements_to_fetch): |
| 214 | + raise RuntimeError("Seems to have duplicates in the database.") |
| 215 | + |
| 216 | + if not allow_no_result: |
| 217 | + # Checks if we have every elements we wanted |
| 218 | + found_keys = {row[column_name] for row in results} |
| 219 | + missing = set(elements_to_fetch) - found_keys |
| 220 | + |
| 221 | + if missing: |
| 222 | + raise missing_elements_error_cls( |
| 223 | + data={column_name: str(missing)}, detail=str(missing) |
| 224 | + ) |
| 225 | + |
| 226 | + return results |
0 commit comments