Skip to content

Commit 4d4f9cb

Browse files
committed
refactor: add validation to storage, repo and session blueprints
1 parent c42f324 commit 4d4f9cb

File tree

9 files changed

+91
-78
lines changed

9 files changed

+91
-78
lines changed

components/renku_data_services/repositories/blueprints.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
from dataclasses import dataclass
44
from urllib.parse import unquote
55

6-
from sanic import HTTPResponse, Request, json
6+
from sanic import HTTPResponse, Request
77
from sanic.response import JSONResponse
88

99
import renku_data_services.base_models as base_models
1010
from renku_data_services import errors
1111
from renku_data_services.base_api.auth import authenticate
1212
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
1313
from renku_data_services.base_api.etag import extract_if_none_match
14+
from renku_data_services.base_models.validation import validated_json
1415
from renku_data_services.repositories import apispec
1516
from renku_data_services.repositories.apispec_base import RepositoryParams
1617
from renku_data_services.repositories.db import GitRepositoriesRepository
@@ -53,10 +54,7 @@ async def get_internal_gitlab_user() -> base_models.APIUser:
5354
if result.repository_metadata and result.repository_metadata.etag is not None
5455
else None
5556
)
56-
return json(
57-
apispec.RepositoryProviderMatch.model_validate(result).model_dump(exclude_none=True, mode="json"),
58-
headers=headers,
59-
)
57+
return validated_json(apispec.RepositoryProviderMatch, result, headers=headers)
6058

6159
return "/repositories/<repository_url>", ["GET"], _get_one_repository
6260

components/renku_data_services/session/apispec_base.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,26 @@ class Config:
1414

1515
from_attributes = True
1616

17+
@field_validator("id", mode="before", check_fields=False)
18+
@classmethod
19+
def serialize_id(cls, id: str | ULID) -> str:
20+
"""Custom serializer that can handle ULIDs."""
21+
return str(id)
22+
1723
@field_validator("project_id", mode="before", check_fields=False)
1824
@classmethod
19-
def serialize_id(cls, project_id: str | ULID) -> str:
25+
def serialize_project_id(cls, project_id: str | ULID) -> str:
2026
"""Custom serializer that can handle ULIDs."""
2127
return str(project_id)
2228

29+
@field_validator("environment_id", mode="before", check_fields=False)
30+
@classmethod
31+
def serialize_environment_id(cls, environment_id: str | ULID | None) -> str | None:
32+
"""Custom serializer that can handle ULIDs."""
33+
if environment_id is None:
34+
return None
35+
return str(environment_id)
36+
2337
@field_validator("environment_kind", mode="before", check_fields=False)
2438
@classmethod
2539
def serialize_environment_kind(cls, environment_kind: models.EnvironmentKind | str) -> str:

components/renku_data_services/session/blueprints.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass
44
from datetime import UTC, datetime
55

6-
from sanic import HTTPResponse, Request, json
6+
from sanic import HTTPResponse, Request
77
from sanic.response import JSONResponse
88
from sanic_ext import validate
99
from ulid import ULID
@@ -12,6 +12,7 @@
1212
from renku_data_services import errors
1313
from renku_data_services.base_api.auth import authenticate, only_authenticated
1414
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
15+
from renku_data_services.base_models.validation import validated_json
1516
from renku_data_services.session import apispec, models
1617
from renku_data_services.session.db import SessionRepository
1718

@@ -28,9 +29,7 @@ def get_all(self) -> BlueprintFactoryResponse:
2829

2930
async def _get_all(_: Request) -> JSONResponse:
3031
environments = await self.session_repo.get_environments()
31-
return json(
32-
[apispec.Environment.model_validate(e).model_dump(exclude_none=True, mode="json") for e in environments]
33-
)
32+
return validated_json(apispec.EnvironmentList, environments)
3433

3534
return "/environments", ["GET"], _get_all
3635

@@ -39,7 +38,7 @@ def get_one(self) -> BlueprintFactoryResponse:
3938

4039
async def _get_one(_: Request, environment_id: ULID) -> JSONResponse:
4140
environment = await self.session_repo.get_environment(environment_id=environment_id)
42-
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"))
41+
return validated_json(apispec.Environment, environment)
4342

4443
return "/environments/<environment_id:ulid>", ["GET"], _get_one
4544

@@ -59,7 +58,7 @@ async def _post(_: Request, user: base_models.APIUser, body: apispec.Environment
5958
creation_date=datetime.now(UTC).replace(microsecond=0),
6059
)
6160
environment = await self.session_repo.insert_environment(user=user, new_environment=environment_model)
62-
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"), 201)
61+
return validated_json(apispec.Environment, environment, 201)
6362

6463
return "/environments", ["POST"], _post
6564

@@ -75,7 +74,7 @@ async def _patch(
7574
environment = await self.session_repo.update_environment(
7675
user=user, environment_id=environment_id, **body_dict
7776
)
78-
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"))
77+
return validated_json(apispec.Environment, environment)
7978

8079
return "/environments/<environment_id:ulid>", ["PATCH"], _patch
8180

@@ -103,12 +102,7 @@ def get_all(self) -> BlueprintFactoryResponse:
103102
@authenticate(self.authenticator)
104103
async def _get_all(_: Request, user: base_models.APIUser) -> JSONResponse:
105104
launchers = await self.session_repo.get_launchers(user=user)
106-
return json(
107-
[
108-
apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json")
109-
for item in launchers
110-
]
111-
)
105+
return validated_json(apispec.SessionLaunchersList, launchers)
112106

113107
return "/session_launchers", ["GET"], _get_all
114108

@@ -118,7 +112,7 @@ def get_one(self) -> BlueprintFactoryResponse:
118112
@authenticate(self.authenticator)
119113
async def _get_one(_: Request, user: base_models.APIUser, launcher_id: ULID) -> JSONResponse:
120114
launcher = await self.session_repo.get_launcher(user=user, launcher_id=launcher_id)
121-
return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"))
115+
return validated_json(apispec.SessionLauncher, launcher)
122116

123117
return "/session_launchers/<launcher_id:ulid>", ["GET"], _get_one
124118

@@ -150,9 +144,7 @@ async def _post(_: Request, user: base_models.APIUser, body: apispec.SessionLaun
150144
creation_date=datetime.now(UTC).replace(microsecond=0),
151145
)
152146
launcher = await self.session_repo.insert_launcher(user=user, new_launcher=launcher_model)
153-
return json(
154-
apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"), 201
155-
)
147+
return validated_json(apispec.SessionLauncher, launcher, 201)
156148

157149
return "/session_launchers", ["POST"], _post
158150

@@ -166,7 +158,7 @@ async def _patch(
166158
) -> JSONResponse:
167159
body_dict = body.model_dump(exclude_none=True)
168160
launcher = await self.session_repo.update_launcher(user=user, launcher_id=launcher_id, **body_dict)
169-
return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"))
161+
return validated_json(apispec.SessionLauncher, launcher)
170162

171163
return "/session_launchers/<launcher_id:ulid>", ["PATCH"], _patch
172164

@@ -186,11 +178,6 @@ def get_project_launchers(self) -> BlueprintFactoryResponse:
186178
@authenticate(self.authenticator)
187179
async def _get_launcher(_: Request, user: base_models.APIUser, project_id: ULID) -> JSONResponse:
188180
launchers = await self.session_repo.get_project_launchers(user=user, project_id=project_id)
189-
return json(
190-
[
191-
apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json")
192-
for item in launchers
193-
]
194-
)
181+
return validated_json(apispec.SessionLaunchersList, launchers)
195182

196183
return "/projects/<project_id:ulid>/session_launchers", ["GET"], _get_launcher

components/renku_data_services/session/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class UnsavedEnvironment(BaseModel):
4141
class Environment(UnsavedEnvironment): # type: ignore[misc]
4242
"""Session environment model."""
4343

44-
id: str
44+
id: ULID
4545

4646

4747
@dataclass(frozen=True, eq=True, kw_only=True)

components/renku_data_services/session/orm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from renku_data_services.crc.orm import ResourceClassORM
1111
from renku_data_services.project.orm import ProjectORM
1212
from renku_data_services.session import models
13+
from renku_data_services.utils.sqlalchemy import ULIDType
1314

1415
metadata_obj = MetaData(schema="sessions") # Has to match alembic ini section name
1516

@@ -25,7 +26,7 @@ class EnvironmentORM(BaseORM):
2526

2627
__tablename__ = "environments"
2728

28-
id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False)
29+
id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False)
2930
"""Id of this session environment object."""
3031

3132
name: Mapped[str] = mapped_column("name", String(99))

components/renku_data_services/storage/api.spec.yaml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ paths:
262262
content:
263263
application/json:
264264
schema:
265-
$ref: "#/components/schemas/RCloneConfig"
265+
$ref: "#/components/schemas/RCloneConfigValidate"
266266
responses:
267267
"204":
268268
description: The configuration is valid
@@ -346,6 +346,16 @@ components:
346346
nullable: true
347347
- type: boolean
348348
- type: object
349+
RCloneConfigValidate: #this is the same as RCloneConfig but duplicated so a class gets generated
350+
type: object
351+
description: Dictionary of rclone key:value pairs (based on schema from '/storage_schema')
352+
additionalProperties:
353+
oneOf:
354+
- type: integer
355+
- type: string
356+
nullable: true
357+
- type: boolean
358+
- type: object
349359
CloudStorageUrl:
350360
allOf:
351361
- $ref: "#/components/schemas/GitRequest"

components/renku_data_services/storage/apispec.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# generated by datamodel-codegen:
22
# filename: api.spec.yaml
3-
# timestamp: 2024-08-06T05:55:29+00:00
3+
# timestamp: 2024-08-09T12:39:58+00:00
44

55
from __future__ import annotations
66

@@ -11,6 +11,12 @@
1111
from renku_data_services.storage.apispec_base import BaseAPISpec
1212

1313

14+
class RCloneConfigValidate(
15+
RootModel[Optional[Dict[str, Union[int, Optional[str], bool, Dict[str, Any]]]]]
16+
):
17+
root: Optional[Dict[str, Union[int, Optional[str], bool, Dict[str, Any]]]] = None
18+
19+
1420
class Example(BaseAPISpec):
1521
value: Optional[str] = Field(
1622
None, description="a potential value for the option (think enum)"

0 commit comments

Comments
 (0)