Skip to content

Commit b68c3b7

Browse files
feat: Add pilot management: create/delete/patch and query
1 parent 6674d37 commit b68c3b7

File tree

1 file changed

+87
-3
lines changed

1 file changed

+87
-3
lines changed

diracx-db/src/diracx/db/sql/utils/functions.py

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,30 @@
22

33
import hashlib
44
from datetime import datetime, timedelta, timezone
5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, Any, Sequence, Type
66

7-
from sqlalchemy import DateTime, func
7+
from sqlalchemy import DateTime, RowMapping, asc, desc, func, select
8+
from sqlalchemy.ext.asyncio import AsyncConnection
89
from sqlalchemy.ext.compiler import compiles
9-
from sqlalchemy.sql import expression
10+
from sqlalchemy.sql import ColumnElement, expression
11+
12+
from diracx.core.exceptions import DiracFormattedError, InvalidQueryError
1013

1114
if TYPE_CHECKING:
1215
from sqlalchemy.types import TypeEngine
1316

1417

18+
def _get_columns(table, parameters):
19+
columns = [x for x in table.columns]
20+
if parameters:
21+
if unrecognised_parameters := set(parameters) - set(table.columns.keys()):
22+
raise InvalidQueryError(
23+
f"Unrecognised parameters requested {unrecognised_parameters}"
24+
)
25+
columns = [c for c in columns if c.name in parameters]
26+
return columns
27+
28+
1529
class utcnow(expression.FunctionElement): # noqa: N801
1630
type: TypeEngine = DateTime()
1731
inherit_cache: bool = True
@@ -140,3 +154,73 @@ def substract_date(**kwargs: float) -> datetime:
140154

141155
def hash(code: str):
142156
return hashlib.sha256(code.encode()).hexdigest()
157+
158+
159+
def raw_hash(code: str):
160+
return hashlib.sha256(code.encode()).digest()
161+
162+
163+
async def fetch_records_bulk_or_raises(
164+
conn: AsyncConnection,
165+
model: Any, # Here, we currently must use `Any` because `declarative_base()` returns any
166+
missing_elements_error_cls: Type[DiracFormattedError],
167+
column_attribute_name: str,
168+
column_name: str,
169+
elements_to_fetch: list,
170+
order_by: tuple[str, str] | None = None,
171+
allow_more_than_one_result_per_input: bool = False,
172+
allow_no_result: bool = False,
173+
) -> Sequence[RowMapping]:
174+
"""Fetches a list of elements in a table, returns a list of elements.
175+
All elements from the `element_to_fetch` **must** be present.
176+
Raises the specified error if at least one is missing.
177+
178+
Example:
179+
fetch_records_bulk_or_raises(
180+
self.conn,
181+
PilotAgents,
182+
PilotNotFound,
183+
"pilot_id",
184+
"PilotID",
185+
[1,2,3]
186+
)
187+
188+
"""
189+
assert elements_to_fetch
190+
191+
# Get the column that needs to be in elements_to_fetch
192+
column = getattr(model, column_attribute_name)
193+
194+
# Create the request
195+
stmt = select(model).with_for_update().where(column.in_(elements_to_fetch))
196+
197+
if order_by:
198+
column_name_to_order_by, direction = order_by
199+
column_to_order_by = getattr(model, column_name_to_order_by)
200+
201+
operator: ColumnElement = (
202+
asc(column_to_order_by) if direction == "asc" else desc(column_to_order_by)
203+
)
204+
205+
stmt = stmt.order_by(operator)
206+
207+
# Transform into dictionaries
208+
raw_results = await conn.execute(stmt)
209+
results = raw_results.mappings().all()
210+
211+
# Detects duplicates
212+
if not allow_more_than_one_result_per_input:
213+
if len(results) > len(elements_to_fetch):
214+
raise RuntimeError("Seems to have duplicates in the database.")
215+
216+
if not allow_no_result:
217+
# Checks if we have every elements we wanted
218+
found_keys = {row[column_name] for row in results}
219+
missing = set(elements_to_fetch) - found_keys
220+
221+
if missing:
222+
raise missing_elements_error_cls(
223+
data={column_name: str(missing)}, detail=str(missing)
224+
)
225+
226+
return results

0 commit comments

Comments
 (0)