Skip to content

Commit 9dac1e0

Browse files
committed
refactor: add validation to storage, repo and session blueprints
1 parent 9640498 commit 9dac1e0

File tree

9 files changed

+177
-96
lines changed

9 files changed

+177
-96
lines changed

components/renku_data_services/repositories/blueprints.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
from dataclasses import dataclass
44
from urllib.parse import unquote
55

6-
from sanic import HTTPResponse, Request, json
6+
from sanic import HTTPResponse, Request
77
from sanic.response import JSONResponse
88

99
import renku_data_services.base_models as base_models
1010
from renku_data_services import errors
1111
from renku_data_services.base_api.auth import authenticate
1212
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
1313
from renku_data_services.base_api.etag import extract_if_none_match
14+
from renku_data_services.base_models.validation import validated_json
1415
from renku_data_services.repositories import apispec
1516
from renku_data_services.repositories.apispec_base import RepositoryParams
1617
from renku_data_services.repositories.db import GitRepositoriesRepository
@@ -53,10 +54,7 @@ async def get_internal_gitlab_user() -> base_models.APIUser:
5354
if result.repository_metadata and result.repository_metadata.etag is not None
5455
else None
5556
)
56-
return json(
57-
apispec.RepositoryProviderMatch.model_validate(result).model_dump(exclude_none=True, mode="json"),
58-
headers=headers,
59-
)
57+
return validated_json(apispec.RepositoryProviderMatch, result, headers=headers)
6058

6159
return "/repositories/<repository_url>", ["GET"], _get_one_repository
6260

components/renku_data_services/session/apispec_base.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""Base models for API specifications."""
22

3-
from pydantic import BaseModel
3+
from pydantic import BaseModel, field_validator
4+
from ulid import ULID
5+
6+
from renku_data_services.session import models
47

58

69
class BaseAPISpec(BaseModel):
@@ -10,3 +13,31 @@ class Config:
1013
"""Enables orm mode for pydantic."""
1114

1215
from_attributes = True
16+
17+
@field_validator("id", mode="before", check_fields=False)
18+
@classmethod
19+
def serialize_id(cls, id: str | ULID) -> str:
20+
"""Custom serializer that can handle ULIDs."""
21+
return str(id)
22+
23+
@field_validator("project_id", mode="before", check_fields=False)
24+
@classmethod
25+
def serialize_project_id(cls, project_id: str | ULID) -> str:
26+
"""Custom serializer that can handle ULIDs."""
27+
return str(project_id)
28+
29+
@field_validator("environment_id", mode="before", check_fields=False)
30+
@classmethod
31+
def serialize_environment_id(cls, environment_id: str | ULID | None) -> str | None:
32+
"""Custom serializer that can handle ULIDs."""
33+
if environment_id is None:
34+
return None
35+
return str(environment_id)
36+
37+
@field_validator("environment_kind", mode="before", check_fields=False)
38+
@classmethod
39+
def serialize_environment_kind(cls, environment_kind: models.EnvironmentKind | str) -> str:
40+
"""Custom serializer that can handle EnvironmentKind enum members."""
41+
if isinstance(environment_kind, models.EnvironmentKind):
42+
return environment_kind.value
43+
return environment_kind

components/renku_data_services/session/blueprints.py

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
"""Session blueprint."""
22

33
from dataclasses import dataclass
4+
from datetime import UTC, datetime
45

5-
from sanic import HTTPResponse, Request, json
6+
from sanic import HTTPResponse, Request
67
from sanic.response import JSONResponse
78
from sanic_ext import validate
89
from ulid import ULID
910

1011
import renku_data_services.base_models as base_models
11-
from renku_data_services.base_api.auth import authenticate, validate_path_project_id
12+
from renku_data_services import errors
13+
from renku_data_services.base_api.auth import authenticate, only_authenticated
1214
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
13-
from renku_data_services.session import apispec
15+
from renku_data_services.base_models.validation import validated_json
16+
from renku_data_services.session import apispec, models
1417
from renku_data_services.session.db import SessionRepository
1518

1619

@@ -26,9 +29,7 @@ def get_all(self) -> BlueprintFactoryResponse:
2629

2730
async def _get_all(_: Request) -> JSONResponse:
2831
environments = await self.session_repo.get_environments()
29-
return json(
30-
[apispec.Environment.model_validate(e).model_dump(exclude_none=True, mode="json") for e in environments]
31-
)
32+
return validated_json(apispec.EnvironmentList, environments)
3233

3334
return "/environments", ["GET"], _get_all
3435

@@ -37,7 +38,7 @@ def get_one(self) -> BlueprintFactoryResponse:
3738

3839
async def _get_one(_: Request, environment_id: ULID) -> JSONResponse:
3940
environment = await self.session_repo.get_environment(environment_id=environment_id)
40-
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"))
41+
return validated_json(apispec.Environment, environment)
4142

4243
return "/environments/<environment_id:ulid>", ["GET"], _get_one
4344

@@ -47,8 +48,17 @@ def post(self) -> BlueprintFactoryResponse:
4748
@authenticate(self.authenticator)
4849
@validate(json=apispec.EnvironmentPost)
4950
async def _post(_: Request, user: base_models.APIUser, body: apispec.EnvironmentPost) -> JSONResponse:
50-
environment = await self.session_repo.insert_environment(user=user, new_environment=body)
51-
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"), 201)
51+
assert user.id
52+
environment_model = models.UnsavedEnvironment(
53+
name=body.name,
54+
description=body.description,
55+
container_image=body.container_image,
56+
default_url=body.default_url,
57+
created_by=models.Member(id=user.id),
58+
creation_date=datetime.now(UTC).replace(microsecond=0),
59+
)
60+
environment = await self.session_repo.insert_environment(user=user, new_environment=environment_model)
61+
return validated_json(apispec.Environment, environment, 201)
5262

5363
return "/environments", ["POST"], _post
5464

@@ -64,7 +74,7 @@ async def _patch(
6474
environment = await self.session_repo.update_environment(
6575
user=user, environment_id=environment_id, **body_dict
6676
)
67-
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"))
77+
return validated_json(apispec.Environment, environment)
6878

6979
return "/environments/<environment_id:ulid>", ["PATCH"], _patch
7080

@@ -92,12 +102,7 @@ def get_all(self) -> BlueprintFactoryResponse:
92102
@authenticate(self.authenticator)
93103
async def _get_all(_: Request, user: base_models.APIUser) -> JSONResponse:
94104
launchers = await self.session_repo.get_launchers(user=user)
95-
return json(
96-
[
97-
apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json")
98-
for item in launchers
99-
]
100-
)
105+
return validated_json(apispec.SessionLaunchersList, launchers)
101106

102107
return "/session_launchers", ["GET"], _get_all
103108

@@ -107,20 +112,39 @@ def get_one(self) -> BlueprintFactoryResponse:
107112
@authenticate(self.authenticator)
108113
async def _get_one(_: Request, user: base_models.APIUser, launcher_id: ULID) -> JSONResponse:
109114
launcher = await self.session_repo.get_launcher(user=user, launcher_id=launcher_id)
110-
return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"))
115+
return validated_json(apispec.SessionLauncher, launcher)
111116

112117
return "/session_launchers/<launcher_id:ulid>", ["GET"], _get_one
113118

114119
def post(self) -> BlueprintFactoryResponse:
115120
"""Create a new session launcher."""
116121

117122
@authenticate(self.authenticator)
123+
@only_authenticated
118124
@validate(json=apispec.SessionLauncherPost)
119125
async def _post(_: Request, user: base_models.APIUser, body: apispec.SessionLauncherPost) -> JSONResponse:
120-
launcher = await self.session_repo.insert_launcher(user=user, new_launcher=body)
121-
return json(
122-
apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"), 201
126+
match body.environment_kind:
127+
case apispec.EnvironmentKind.global_environment:
128+
environment_kind = models.EnvironmentKind.global_environment
129+
case apispec.EnvironmentKind.container_image:
130+
environment_kind = models.EnvironmentKind.container_image
131+
case _:
132+
raise errors.ValidationError(message=f"Unknown environment kind {body.environment_kind}")
133+
assert user.id
134+
launcher_model = models.UnsavedSessionLauncher(
135+
name=body.name,
136+
project_id=ULID.from_str(body.project_id),
137+
description=body.description,
138+
environment_kind=environment_kind,
139+
environment_id=body.environment_id,
140+
resource_class_id=body.resource_class_id,
141+
container_image=body.container_image,
142+
default_url=body.default_url,
143+
created_by=models.Member(id=user.id),
144+
creation_date=datetime.now(UTC).replace(microsecond=0),
123145
)
146+
launcher = await self.session_repo.insert_launcher(user=user, new_launcher=launcher_model)
147+
return validated_json(apispec.SessionLauncher, launcher, 201)
124148

125149
return "/session_launchers", ["POST"], _post
126150

@@ -134,7 +158,7 @@ async def _patch(
134158
) -> JSONResponse:
135159
body_dict = body.model_dump(exclude_none=True)
136160
launcher = await self.session_repo.update_launcher(user=user, launcher_id=launcher_id, **body_dict)
137-
return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"))
161+
return validated_json(apispec.SessionLauncher, launcher)
138162

139163
return "/session_launchers/<launcher_id:ulid>", ["PATCH"], _patch
140164

@@ -152,14 +176,8 @@ def get_project_launchers(self) -> BlueprintFactoryResponse:
152176
"""Get all launchers belonging to a project."""
153177

154178
@authenticate(self.authenticator)
155-
@validate_path_project_id
156-
async def _get_launcher(_: Request, user: base_models.APIUser, project_id: str) -> JSONResponse:
179+
async def _get_launcher(_: Request, user: base_models.APIUser, project_id: ULID) -> JSONResponse:
157180
launchers = await self.session_repo.get_project_launchers(user=user, project_id=project_id)
158-
return json(
159-
[
160-
apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json")
161-
for item in launchers
162-
]
163-
)
181+
return validated_json(apispec.SessionLaunchersList, launchers)
164182

165-
return "/projects/<project_id>/session_launchers", ["GET"], _get_launcher
183+
return "/projects/<project_id:ulid>/session_launchers", ["GET"], _get_launcher

components/renku_data_services/session/models.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,20 @@
22

33
from dataclasses import dataclass
44
from datetime import datetime
5+
from enum import Enum
6+
from typing import Self
57

68
from pydantic import BaseModel, model_validator
9+
from ulid import ULID
710

811
from renku_data_services import errors
9-
from renku_data_services.session.apispec import EnvironmentKind
12+
13+
14+
class EnvironmentKind(Enum):
15+
"""Environment kind enum."""
16+
17+
global_environment = "global_environment"
18+
container_image = "container_image"
1019

1120

1221
@dataclass(frozen=True, eq=True, kw_only=True)
@@ -17,10 +26,9 @@ class Member(BaseModel):
1726

1827

1928
@dataclass(frozen=True, eq=True, kw_only=True)
20-
class Environment(BaseModel):
21-
"""Session environment model."""
29+
class UnsavedEnvironment(BaseModel):
30+
"""Session environment model that isn't in the db yet."""
2231

23-
id: str | None
2432
name: str
2533
creation_date: datetime
2634
description: str | None
@@ -30,11 +38,17 @@ class Environment(BaseModel):
3038

3139

3240
@dataclass(frozen=True, eq=True, kw_only=True)
33-
class SessionLauncher(BaseModel):
34-
"""Session launcher model."""
41+
class Environment(UnsavedEnvironment): # type: ignore[misc]
42+
"""Session environment model."""
43+
44+
id: ULID
45+
3546

36-
id: str | None
37-
project_id: str
47+
@dataclass(frozen=True, eq=True, kw_only=True)
48+
class UnsavedSessionLauncher(BaseModel):
49+
"""Session launcher model that isn't in the db yet."""
50+
51+
project_id: ULID
3852
name: str
3953
creation_date: datetime
4054
description: str | None
@@ -46,7 +60,7 @@ class SessionLauncher(BaseModel):
4660
created_by: Member
4761

4862
@model_validator(mode="after")
49-
def check_launcher_environment_kind(self) -> "SessionLauncher":
63+
def check_launcher_environment_kind(self) -> Self:
5064
"""Validates the environment of a launcher."""
5165

5266
environment_kind = self.environment_kind
@@ -60,3 +74,10 @@ def check_launcher_environment_kind(self) -> "SessionLauncher":
6074
raise errors.ValidationError(message="'container_image' not set when environment_kind=container_image")
6175

6276
return self
77+
78+
79+
@dataclass(frozen=True, eq=True, kw_only=True)
80+
class SessionLauncher(UnsavedSessionLauncher): # type: ignore[misc]
81+
"""Session launcher model."""
82+
83+
id: str

components/renku_data_services/session/orm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from renku_data_services.crc.orm import ResourceClassORM
1111
from renku_data_services.project.orm import ProjectORM
1212
from renku_data_services.session import models
13-
from renku_data_services.session.apispec import EnvironmentKind
13+
from renku_data_services.utils.sqlalchemy import ULIDType
1414

1515
metadata_obj = MetaData(schema="sessions") # Has to match alembic ini section name
1616

@@ -26,7 +26,7 @@ class EnvironmentORM(BaseORM):
2626

2727
__tablename__ = "environments"
2828

29-
id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False)
29+
id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False)
3030
"""Id of this session environment object."""
3131

3232
name: Mapped[str] = mapped_column("name", String(99))
@@ -48,7 +48,7 @@ class EnvironmentORM(BaseORM):
4848
"""Default URL path to open in a session."""
4949

5050
@classmethod
51-
def load(cls, environment: models.Environment) -> "EnvironmentORM":
51+
def load(cls, environment: models.UnsavedEnvironment) -> "EnvironmentORM":
5252
"""Create EnvironmentORM from the session environment model."""
5353
return cls(
5454
name=environment.name,
@@ -92,7 +92,7 @@ class SessionLauncherORM(BaseORM):
9292
description: Mapped[str | None] = mapped_column("description", String(500))
9393
"""Human-readable description of the session launcher."""
9494

95-
environment_kind: Mapped[EnvironmentKind]
95+
environment_kind: Mapped[models.EnvironmentKind]
9696
"""The kind of environment definition to use."""
9797

9898
container_image: Mapped[str | None] = mapped_column("container_image", String(500))
@@ -124,7 +124,7 @@ class SessionLauncherORM(BaseORM):
124124
"""Id of the resource class."""
125125

126126
@classmethod
127-
def load(cls, launcher: models.SessionLauncher) -> "SessionLauncherORM":
127+
def load(cls, launcher: models.UnsavedSessionLauncher) -> "SessionLauncherORM":
128128
"""Create SessionLauncherORM from the session launcher model."""
129129
return cls(
130130
name=launcher.name,

components/renku_data_services/storage/api.spec.yaml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ paths:
262262
content:
263263
application/json:
264264
schema:
265-
$ref: "#/components/schemas/RCloneConfig"
265+
$ref: "#/components/schemas/RCloneConfigValidate"
266266
responses:
267267
"204":
268268
description: The configuration is valid
@@ -346,6 +346,16 @@ components:
346346
nullable: true
347347
- type: boolean
348348
- type: object
349+
RCloneConfigValidate: # Same schema as RCloneConfig, duplicated so that datamodel-codegen generates a class for it
350+
type: object
351+
description: Dictionary of rclone key:value pairs (based on schema from '/storage_schema')
352+
additionalProperties:
353+
oneOf:
354+
- type: integer
355+
- type: string
356+
nullable: true
357+
- type: boolean
358+
- type: object
349359
CloudStorageUrl:
350360
allOf:
351361
- $ref: "#/components/schemas/GitRequest"

components/renku_data_services/storage/apispec.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# generated by datamodel-codegen:
22
# filename: api.spec.yaml
3-
# timestamp: 2024-08-06T05:55:29+00:00
3+
# timestamp: 2024-08-09T12:39:58+00:00
44

55
from __future__ import annotations
66

@@ -11,6 +11,12 @@
1111
from renku_data_services.storage.apispec_base import BaseAPISpec
1212

1313

14+
class RCloneConfigValidate(
15+
RootModel[Optional[Dict[str, Union[int, Optional[str], bool, Dict[str, Any]]]]]
16+
):
17+
root: Optional[Dict[str, Union[int, Optional[str], bool, Dict[str, Any]]]] = None
18+
19+
1420
class Example(BaseAPISpec):
1521
value: Optional[str] = Field(
1622
None, description="a potential value for the option (think enum)"

0 commit comments

Comments
 (0)