Skip to content

Commit 6b515ac

Browse files
authored
bookmark model and router (#560)
2 parents c0a92dc + b1460ff commit 6b515ac

File tree

9 files changed

+288
-4
lines changed

9 files changed

+288
-4
lines changed

src/database/model/bookmark/__init__.py

Whitespace-only changes.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from sqlmodel import SQLModel, Field
2+
from typing import Optional
3+
from datetime import datetime
4+
from sqlalchemy import Column, String
5+
6+
7+
class Bookmark(SQLModel, table=True): # type: ignore [call-arg]
8+
__tablename__ = "bookmark"
9+
user_identifier: str = Field(
10+
foreign_key="user.subject_identifier",
11+
nullable=False,
12+
ondelete="CASCADE",
13+
primary_key=True,
14+
description="The sub-identifier of the user who created the bookmark.",
15+
)
16+
resource_identifier: str = Field(
17+
primary_key=True, description="The identifier of the resource being bookmarked."
18+
)
19+
created_at: datetime = Field(
20+
default_factory=datetime.utcnow, description="The time when the bookmark was created."
21+
)

src/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
search_routers,
4141
review_router,
4242
user_router,
43+
bookmark_router,
4344
)
4445
from setup_logger import setup_logger
4546

@@ -85,7 +86,7 @@ def counts() -> dict:
8586
+ parent_routers.router_list
8687
+ enum_routers.router_list
8788
+ search_routers.router_list
88-
+ [review_router, user_router]
89+
+ [review_router, user_router, bookmark_router]
8990
):
9091
app.include_router(router.create(url_prefix))
9192

src/routers/bookmark_router.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import sqlalchemy.exc
2+
from fastapi import APIRouter, Depends, HTTPException
3+
from typing import List, cast
4+
from sqlmodel import Session, select, Field, SQLModel
5+
6+
from authentication import KeycloakUser, get_user_or_raise
7+
from database.session import get_session
8+
from database.model.bookmark.bookmark import Bookmark
9+
from http import HTTPStatus
10+
from datetime import datetime
11+
from routers.helper_functions import get_asset_type_by_abbreviation
12+
13+
14+
class BookmarkRead(SQLModel):
15+
resource_identifier: str = Field(description="The identifier of the resource being bookmarked.")
16+
created_at: datetime = Field(
17+
description="The time when the bookmark was created in ISO 8601 format."
18+
)
19+
20+
class Config:
21+
json_encoders = {datetime: lambda dt: dt.isoformat()}
22+
23+
24+
def create(url_prefix: str = "") -> APIRouter:
25+
router = APIRouter()
26+
27+
for path in [
28+
f"{url_prefix}/v2/bookmarks",
29+
f"{url_prefix}/bookmarks",
30+
]:
31+
32+
@router.get(
33+
path,
34+
tags=["User"],
35+
description="Return all your bookmarks.",
36+
response_model=List[BookmarkRead],
37+
)
38+
def list_bookmarks(
39+
user: KeycloakUser = Depends(get_user_or_raise), session: Session = Depends(get_session)
40+
) -> List[BookmarkRead]:
41+
return session.exec(
42+
select(Bookmark).where(Bookmark.user_identifier == user._subject_identifier)
43+
).all()
44+
45+
@router.post(
46+
path,
47+
tags=["User"],
48+
response_model=BookmarkRead,
49+
description="Add the asset to the logged-in user's bookmarks."
50+
"If it was already bookmarked, return the existing bookmark.",
51+
status_code=HTTPStatus.OK,
52+
)
53+
def create_bookmark(
54+
resource_identifier: str,
55+
user: KeycloakUser = Depends(get_user_or_raise),
56+
session: Session = Depends(get_session),
57+
) -> BookmarkRead:
58+
if not resource_identifier_exists_in_database(resource_identifier, session):
59+
raise HTTPException(
60+
status_code=HTTPStatus.NOT_FOUND,
61+
detail=f"Resource {resource_identifier} does not exist.",
62+
)
63+
64+
try:
65+
bookmark = Bookmark(
66+
user_identifier=user._subject_identifier,
67+
resource_identifier=resource_identifier,
68+
)
69+
session.add(bookmark)
70+
session.commit()
71+
except sqlalchemy.exc.IntegrityError: # The entry already exists
72+
session.rollback()
73+
bookmark = session.get(Bookmark, (user._subject_identifier, resource_identifier))
74+
return cast(BookmarkRead, bookmark)
75+
76+
@router.delete(
77+
path,
78+
tags=["User"],
79+
description="Delete a bookmark for the logged-in user by resource identifier."
80+
"Also returns HTTP status code OK (200) if no such bookmark existed.",
81+
status_code=HTTPStatus.OK,
82+
)
83+
def delete_bookmark(
84+
resource_identifier: str,
85+
user: KeycloakUser = Depends(get_user_or_raise),
86+
session: Session = Depends(get_session),
87+
):
88+
bookmark = session.get(Bookmark, (user._subject_identifier, resource_identifier))
89+
if bookmark:
90+
session.delete(bookmark)
91+
session.commit()
92+
return None
93+
94+
return router
95+
96+
97+
def resource_identifier_exists_in_database(resource_identifier: str, session: Session) -> bool:
98+
"""
99+
Returns True if the given identifier exists in any of the tables.
100+
"""
101+
asset_type = get_asset_type_by_abbreviation().get(resource_identifier.split("_")[0], None)
102+
if asset_type:
103+
query = select(asset_type).where(asset_type.identifier == resource_identifier)
104+
return session.exec(query).first() is not None
105+
106+
return False

src/routers/helper_functions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import routers
22
from database.model.concept.concept import AIoDConcept
33
from database.model.helper_functions import non_abstract_subclasses
4+
from functools import cache
45

56

67
def get_all_read_classes() -> dict[str, AIoDConcept]:
@@ -22,3 +23,12 @@ def get_all_asset_schemas():
2223
return [
2324
{"$ref": f"#/components/schemas/{clz.__name__}"} for clz in get_all_read_classes().values()
2425
]
26+
27+
28+
@cache
29+
def get_asset_type_by_abbreviation() -> dict[str, type[AIoDConcept]]:
30+
return {
31+
cls.__abbreviation__: cls
32+
for cls in non_abstract_subclasses(AIoDConcept)
33+
if hasattr(cls, "__abbreviation__")
34+
}

src/routers/resource_routers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .resource_bundle_router import ResourceBundleRouter
1818
from .. import ResourceRouter
1919

20+
2021
router_list: list[ResourceRouter | PlatformRouter] = [
2122
PlatformRouter(),
2223
CaseStudyRouter(),

src/tests/routers/generic/test_router_get_count.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from database.session import DbSession
1010
from tests.testutils.test_resource import factory_test_resource
1111
from tests.testutils.users import register_asset
12-
12+
from database.model.concept.aiod_entry import EntryStatus, AIoDEntryORM
1313

1414
def test_get_count_happy_path(client_test_resource: TestClient):
1515
with DbSession() as session:
@@ -60,7 +60,6 @@ def test_get_count_detailed_happy_path(client_test_resource: TestClient):
6060
assert response_json == {"aiod": 1, "example": 2, "openml": 1}
6161
assert "deprecated" not in response.headers
6262

63-
from database.model.concept.aiod_entry import EntryStatus, AIoDEntryORM
6463
# default platfrom is "aiod"
6564
def test_get_count_total(
6665
client: TestClient,
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
from http import HTTPStatus
2+
3+
from starlette.testclient import TestClient
4+
from database.session import DbSession
5+
6+
from tests.testutils.users import logged_in_user, ALICE
7+
from database.model.bookmark.bookmark import Bookmark
8+
from database.session import DbSession
9+
from tests.testutils.users import register_asset, register_user
10+
from datetime import datetime
11+
from database.model.agent.person import Person
12+
from database.model.agent.contact import Contact
13+
14+
15+
def test_create_bookmark(
16+
client: TestClient,
17+
person: Person) -> None:
18+
19+
identifier = register_asset(person)
20+
21+
with logged_in_user():
22+
response = client.post(
23+
f"/bookmarks?resource_identifier={identifier}",
24+
headers={"Authorization": "fake token"},
25+
)
26+
assert response.status_code == HTTPStatus.OK
27+
bookmark = response.json()
28+
assert bookmark["resource_identifier"] == identifier
29+
now = datetime.utcnow()
30+
assert (now - datetime.fromisoformat(bookmark["created_at"])).total_seconds() < 2
31+
32+
33+
def test_create_bookmark_for_non_existing_resource(client: TestClient) -> None:
34+
# Create bookmark for non existing resource.
35+
36+
with logged_in_user():
37+
response = client.post(
38+
f"/bookmarks?resource_identifier=wrong_indetifier",
39+
headers={"Authorization": "fake token"},
40+
)
41+
42+
assert response.status_code == HTTPStatus.NOT_FOUND
43+
assert response.json()["detail"] == f"Resource wrong_indetifier does not exist."
44+
45+
46+
def test_create_duplicate(
47+
client: TestClient,
48+
person: Person
49+
) -> None:
50+
51+
with DbSession() as session:
52+
identifier = register_asset(person)
53+
user = register_user(ALICE, session)
54+
now = datetime.utcnow()
55+
bookmark = Bookmark(
56+
user_identifier=user.subject_identifier,
57+
resource_identifier=identifier,
58+
created_at=now
59+
)
60+
session.add(bookmark)
61+
session.commit()
62+
63+
# Attempt to create a duplicate bookmark
64+
with logged_in_user(ALICE):
65+
response = client.post(
66+
f"/bookmarks?resource_identifier={identifier}",
67+
headers={"Authorization": "fake token"},
68+
)
69+
70+
assert response.status_code == HTTPStatus.OK
71+
bookmark = response.json()
72+
assert bookmark["resource_identifier"] == identifier
73+
assert bookmark["created_at"] == now.isoformat()
74+
75+
76+
def test_get_bookmarks(client: TestClient, person: Person, contact: Contact) -> None:
77+
with DbSession() as session:
78+
prsn_id = register_asset(person)
79+
contact_id = register_asset(contact)
80+
user = register_user(ALICE, session)
81+
session.commit()
82+
83+
# Add a bookmark
84+
with logged_in_user(ALICE):
85+
response = client.post(
86+
f"/bookmarks?resource_identifier={prsn_id}",
87+
headers={"Authorization": "fake token"},
88+
)
89+
assert response.status_code == HTTPStatus.OK
90+
91+
response = client.post(
92+
f"/bookmarks?resource_identifier={contact_id}",
93+
headers={"Authorization": "fake token"},
94+
)
95+
assert response.status_code == HTTPStatus.OK
96+
97+
# Fetch bookmarks
98+
response = client.get(
99+
"/bookmarks",
100+
headers={"Authorization": "fake token"},
101+
)
102+
103+
assert response.status_code == HTTPStatus.OK
104+
bookmarks = response.json()
105+
assert len(bookmarks) == 2
106+
107+
108+
def test_delete_bookmark(
109+
client: TestClient,
110+
person: Person
111+
) -> None:
112+
113+
114+
with DbSession() as session:
115+
identifier = register_asset(person)
116+
register_user(ALICE, session)
117+
session.commit()
118+
119+
with logged_in_user(ALICE):
120+
response = client.post(
121+
f"/bookmarks?resource_identifier={identifier}",
122+
headers={"Authorization": "fake token"},
123+
)
124+
assert response.status_code == HTTPStatus.OK
125+
126+
127+
with logged_in_user(ALICE):
128+
response = client.delete(
129+
f"/bookmarks?resource_identifier={identifier}",
130+
headers={"Authorization": "fake token"},
131+
)
132+
assert response.status_code == HTTPStatus.OK
133+
assert response.json() == None
134+
135+
# Confirm it's deleted
136+
with logged_in_user(ALICE):
137+
response = client.get(
138+
"/bookmarks",
139+
headers={"Authorization": "fake token"},
140+
)
141+
assert response.status_code == HTTPStatus.OK
142+
assert all(b["resource_identifier"] != identifier for b in response.json())

src/tests/testutils/default_sqlalchemy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,11 @@ def clear_db(request, engine: Engine):
125125
with engine.connect() as connection:
126126
transaction = connection.begin()
127127
for table in reversed(SQLModel.metadata.sorted_tables):
128-
connection.execute(table.delete())
128+
try:
129+
connection.execute(table.delete())
130+
except Exception as e:
131+
print(f"Error while clearing table {table.name}: {e}")
132+
raise
129133
transaction.commit()
130134

131135

0 commit comments

Comments
 (0)