Skip to content

Commit 732238d

Browse files
authored
Add sorting to listing endpoints (#653)
* Add deterministic order for pagination * Add sorting to list endpoints * Add test for sorting behavior
1 parent c595bd5 commit 732238d

File tree

12 files changed

+148
-10
lines changed

12 files changed

+148
-10
lines changed

src/dependencies/sorting.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from enum import StrEnum, auto
2+
from typing import Annotated
3+
4+
from fastapi import Query, Depends
5+
from pydantic import BaseModel
6+
from sqlmodel import Field
7+
8+
9+
class SortDirection(StrEnum):
10+
ASC = auto()
11+
DESC = auto()
12+
13+
14+
class Sort(StrEnum):
15+
DATE_CREATED = auto()
16+
DATE_MODIFIED = auto()
17+
18+
19+
class Sorting(BaseModel):
20+
"""Sorting modes for any AIoDConcept with an AIoD entry."""
21+
22+
direction: SortDirection = Field(
23+
Query(
24+
description="The direction of the sort (ascending or descending).",
25+
default=SortDirection.DESC,
26+
)
27+
)
28+
sort: Sort = Field(
29+
Query(
30+
description="The property to sort by.",
31+
default=Sort.DATE_MODIFIED,
32+
)
33+
)
34+
35+
36+
SortingParams = Annotated[Sorting, Depends(Sorting)]

src/routers/bookmark_router.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def list_bookmarks(
4848
stmt = (
4949
select(Bookmark)
5050
.where(Bookmark.user_identifier == user._subject_identifier)
51+
.order_by(Bookmark.created_at)
5152
.offset(pagination.offset)
5253
.limit(pagination.limit)
5354
)

src/routers/resource_router.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from database.session import DbSession
2929
from dependencies.filtering import ResourceFilters, ResourceFiltersParams
3030
from dependencies.pagination import Pagination, PaginationParams
31+
from dependencies.sorting import SortingParams, Sorting, SortDirection
3132
from error_handling import as_http_exception
3233
from database.model.ai_asset.distribution import Distribution
3334
from database.model.helper_functions import get_asset_type_by_abbreviation
@@ -210,6 +211,7 @@ def get_resources(
210211
self,
211212
schema: str,
212213
pagination: Pagination,
214+
sorting: Sorting,
213215
resource_filters: ResourceFilters,
214216
user: KeycloakUser | None = None,
215217
platform: str | None = None,
@@ -227,7 +229,7 @@ def get_resources(
227229
else cast(Callable, self.orm_to_read)
228230
)
229231
resources: Any = self._retrieve_resources_and_post_process(
230-
session, pagination, resource_filters, user, platform
232+
session, pagination, sorting, resource_filters, user, platform
231233
)
232234
for resource in resources:
233235
if not get_image and hasattr(resource, "media"):
@@ -289,13 +291,15 @@ def get_resources_func(self):
289291

290292
def get_resources(
291293
pagination: PaginationParams,
294+
sorting: SortingParams,
292295
resource_filters: ResourceFiltersParams,
293296
schema: self._possible_schemas_type = "aiod", # type:ignore
294297
user: KeycloakUser | None = Depends(get_user_or_none),
295298
):
296299
resources = self.get_resources(
297300
schema=schema,
298301
pagination=pagination,
302+
sorting=sorting,
299303
resource_filters=resource_filters,
300304
user=user,
301305
platform=None,
@@ -367,13 +371,15 @@ def get_resources(
367371
),
368372
],
369373
pagination: PaginationParams,
374+
sorting: SortingParams,
370375
resource_filters: ResourceFiltersParams,
371376
schema: self._possible_schemas_type = "aiod", # type:ignore
372377
user: KeycloakUser | None = Depends(get_user_or_none),
373378
):
374379
resources = self.get_resources(
375380
schema=schema,
376381
pagination=pagination,
382+
sorting=sorting,
377383
resource_filters=resource_filters,
378384
user=user,
379385
platform=platform,
@@ -725,6 +731,7 @@ def _retrieve_resources(
725731
self,
726732
session: Session,
727733
pagination: Pagination,
734+
sorting: Sorting,
728735
resource_filters: ResourceFilters,
729736
platform: str | None = None,
730737
) -> Sequence[type[RESOURCE_MODEL]]:
@@ -743,10 +750,17 @@ def _retrieve_resources(
743750
else True,
744751
AIoDEntryORM.status == EntryStatus.PUBLISHED,
745752
)
753+
sort_attribute = getattr(AIoDEntryORM, sorting.sort.lower())
754+
sort = (
755+
sort_attribute.asc()
756+
if sorting.direction == SortDirection.ASC
757+
else sort_attribute.desc()
758+
)
746759
query = (
747760
select(self.resource_class)
748761
.join(self.resource_class.aiod_entry, isouter=True)
749762
.where(where_clause)
763+
.order_by(sort, AIoDEntryORM.identifier.asc()) # type: ignore[attr-defined]
750764
.offset(pagination.offset)
751765
.limit(pagination.limit)
752766
)
@@ -773,6 +787,7 @@ def _retrieve_resources_and_post_process(
773787
self,
774788
session: Session,
775789
pagination: Pagination,
790+
sorting: Sorting,
776791
resource_filters: ResourceFilters,
777792
user: KeycloakUser | None = None,
778793
platform: str | None = None,
@@ -783,7 +798,7 @@ def _retrieve_resources_and_post_process(
783798
implement further verification on user access to the resource.
784799
"""
785800
resources: Sequence[type[RESOURCE_MODEL]] = self._retrieve_resources(
786-
session, pagination, resource_filters, platform
801+
session, pagination, sorting, resource_filters, platform
787802
)
788803
return self._mask_or_filter(resources, session, user)
789804

src/routers/resource_routers/organisation_router.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from database.model.agent.organisation import Organisation, organisation_versions
2+
from dependencies.sorting import SortingParams
23
from routers.resource_router import ResourceRouter
34
from fastapi import UploadFile, File, HTTPException, Query, status, APIRouter, Depends
45
from http import HTTPStatus
@@ -43,6 +44,7 @@ def resource_class(self) -> type[Organisation]:
4344
def get_resources_func(self):
4445
def get_resources(
4546
pagination: PaginationParams,
47+
sorting: SortingParams,
4648
resource_filters: ResourceFiltersParams,
4749
schema: self._possible_schemas_type = "aiod", # type:ignore
4850
get_image: bool = Query(False, description="Include image bytes in response?"),
@@ -51,6 +53,7 @@ def get_resources(
5153
return self.get_resources(
5254
schema=schema,
5355
pagination=pagination,
56+
sorting=sorting,
5457
resource_filters=resource_filters,
5558
user=user,
5659
get_image=get_image,

src/routers/resource_routers/platform_router.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,12 @@ def _retrieve_resources(
301301
"""
302302
Retrieve a sequence of resources from the database based on the provided identifier.
303303
"""
304-
query = select(self.resource_class).offset(pagination.offset).limit(pagination.limit)
304+
query = (
305+
select(self.resource_class)
306+
.order_by(Platform.identifier)
307+
.offset(pagination.offset)
308+
.limit(pagination.limit)
309+
)
305310
resources: Sequence = session.scalars(query).all()
306311
return resources
307312

src/routers/resource_routers/project_router.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from database.model.project.project import Project, project_versions
55
from dependencies.filtering import ResourceFiltersParams
66
from dependencies.pagination import PaginationParams
7+
from dependencies.sorting import SortingParams
78
from routers.resource_routers.organisation_router import add_custom_routes
89
from routers.resource_router import ResourceRouter
910
from versioning import Version
@@ -30,6 +31,7 @@ def resource_class(self) -> type[Project]:
3031
def get_resources_func(self):
3132
def get_resources(
3233
pagination: PaginationParams,
34+
sorting: SortingParams,
3335
resource_filters: ResourceFiltersParams,
3436
schema: self._possible_schemas_type = "aiod", # type:ignore
3537
get_image: bool = Query(False, description="Include image bytes in response?"),
@@ -38,6 +40,7 @@ def get_resources(
3840
return self.get_resources(
3941
schema=schema,
4042
pagination=pagination,
43+
sorting=sorting,
4144
resource_filters=resource_filters,
4245
user=user,
4346
get_image=get_image,

src/routers/user_router.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
from typing import List
23

34
from fastapi import APIRouter, Depends
@@ -6,6 +7,7 @@
67
from sqlmodel import Session
78

89
from dependencies.pagination import PaginationParams
10+
from dependencies.sorting import SortingParams, SortDirection, Sort
911
from routers.resource_routers import versioned_routers
1012
from authentication import KeycloakUser, get_user_or_raise
1113
from database.authorization import Permission, PermissionType
@@ -48,6 +50,7 @@ def create(url_prefix: str, version: Version) -> APIRouter:
4850
)
4951
def get_versioned_resources_for_user(
5052
pagination: PaginationParams,
53+
sorting: SortingParams,
5154
user: KeycloakUser = Depends(get_user_or_raise),
5255
session: Session = Depends(get_session),
5356
) -> dict[str, list[AIoDConcept]]:
@@ -60,6 +63,8 @@ def get_versioned_resources_for_user(
6063
session,
6164
offset=pagination.offset,
6265
limit=limit,
66+
sort_by=sorting.sort,
67+
sort_direction=sorting.direction,
6368
)
6469
all_assets = [asset for assets in resources.values() for asset in assets]
6570
for resource in [a for a in all_assets if hasattr(a, "media")]:
@@ -70,8 +75,17 @@ def get_versioned_resources_for_user(
7075
r.resource_class.__tablename__: r.orm_to_read
7176
for r in versioned_routers.get(version, [])
7277
}
78+
79+
def sort_function(asset):
80+
value = getattr(asset.aiod_entry, sorting.sort)
81+
direction = -1 if sorting.direction == SortDirection.DESC else 1
82+
return direction * datetime.datetime.timestamp(value)
83+
7384
return {
74-
asset_name: [orm_to_read[asset_name](asset) for asset in assets]
85+
asset_name: sorted(
86+
(orm_to_read[asset_name](asset) for asset in assets),
87+
key=sort_function,
88+
)
7589
for asset_name, assets in resources.items()
7690
}
7791

@@ -84,15 +98,20 @@ def _get_resources_for_user(
8498
*,
8599
offset: int = 0,
86100
limit: int | None = None,
101+
sort_by: Sort = Sort.DATE_MODIFIED,
102+
sort_direction: SortDirection = SortDirection.DESC,
87103
) -> dict[str, list[AIoDConcept]]:
88104
# "Ownership" is currently equivalent to having ADMIN permissions
105+
sort_attribute = getattr(AIoDEntryORM, sort_by.lower())
106+
sort = sort_attribute.asc() if sort_direction == SortDirection.ASC else sort_attribute.desc()
89107
stmt = (
90108
select(AIoDEntryORM)
91109
.join(Permission.aiod_entry)
92110
.where(
93111
Permission.user_identifier == user._subject_identifier,
94112
Permission.type_ == PermissionType.ADMIN,
95113
)
114+
.order_by(sort, AIoDEntryORM.identifier.asc()) # type: ignore[attr-defined]
96115
.offset(offset)
97116
)
98117
if limit:

src/tests/routers/generic/test_router_get_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_get_all_happy_path(client_test_resource: TestClient):
1919
with DbSession() as session:
2020
session.add_all(resources)
2121
session.commit()
22-
response = client_test_resource.get("/test_resources")
22+
response = client_test_resource.get("/test_resources?direction=asc")
2323
assert response.status_code == 200, response.json()
2424
response_json = response.json()
2525

src/tests/routers/generic/test_router_platform_get_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_get_all_happy_path(client_test_resource: TestClient, auto_publish):
2222
with DbSession() as session:
2323
session.add_all(resources)
2424
session.commit()
25-
response = client_test_resource.get("/platforms/example/test_resources")
25+
response = client_test_resource.get("/platforms/example/test_resources?direction=asc")
2626
assert response.status_code == 200, response.json()
2727
response_json = response.json()
2828

src/tests/routers/generic/test_router_relations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def test_get_happy_path(test_objects: list[TestObject], client_with_testobject:
181181

182182

183183
def test_get_all_happy_path(client_with_testobject: TestClient):
184-
response = client_with_testobject.get("/test_resources")
184+
response = client_with_testobject.get("/test_resources?direction=asc")
185185
assert response.status_code == 200, response.json()
186186
response_json = response.json()
187187
assert "deprecated" not in response.headers
@@ -213,7 +213,7 @@ def test_post_happy_path(client_with_testobject: TestClient, auto_publish: None)
213213
headers={"Authorization": "Fake token"},
214214
)
215215
assert response.status_code == 200, response.json()
216-
objects = client_with_testobject.get("/test_resources").json()
216+
objects = client_with_testobject.get("/test_resources?direction=asc").json()
217217
obj = objects[-1]
218218
assert obj["title"] == "title"
219219
assert obj["named_string"] == "named_string1"

0 commit comments

Comments
 (0)