Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions components/renku_data_services/base_models/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,23 @@
from renku_data_services import errors


def validate_and_dump(
model: type[BaseModel],
data: Any,
exclude_none: bool = True,
) -> Any:
"""Validate and dump with a pydantic model, ensuring proper validation errors."""
try:
body = model.model_validate(data).model_dump(exclude_none=exclude_none, mode="json")
except PydanticValidationError as err:
parts = [".".join(str(i) for i in field["loc"]) + ": " + field["msg"] for field in err.errors()]
message = (
f"The server could not construct a valid response. Errors found in the following fields: {', '.join(parts)}"
)
raise errors.ProgrammingError(message=message) from err
return body


def validated_json(
model: type[BaseModel],
data: Any,
Expand All @@ -25,12 +42,5 @@ def validated_json(

If the input data fails validation, an HTTP status code 500 will be raised.
"""
try:
body = model.model_validate(data).model_dump(exclude_none=exclude_none, mode="json")
except PydanticValidationError as err:
parts = [".".join(str(i) for i in field["loc"]) + ": " + field["msg"] for field in err.errors()]
message = (
f"The server could not construct a valid response. Errors found in the following fields: {', '.join(parts)}"
)
raise errors.ProgrammingError(message=message) from err
body = validate_and_dump(model, data, exclude_none)
return json(body, status=status, headers=headers, content_type=content_type, dumps=dumps, **kwargs)
19 changes: 8 additions & 11 deletions components/renku_data_services/connected_services/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from renku_data_services.base_api.auth import authenticate, only_admins, only_authenticated
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_api.misc import validate_query
from renku_data_services.base_models.validation import validated_json
from renku_data_services.connected_services import apispec
from renku_data_services.connected_services.apispec_base import AuthorizeParams, CallbackParams
from renku_data_services.connected_services.db import ConnectedServicesRepository
Expand All @@ -30,9 +31,7 @@ def get_all(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_all(_: Request, user: base_models.APIUser) -> JSONResponse:
clients = await self.connected_services_repo.get_oauth2_clients(user=user)
return json(
[apispec.Provider.model_validate(c).model_dump(exclude_none=True, mode="json") for c in clients]
)
return validated_json(apispec.ProviderList, clients)

return "/oauth2/providers", ["GET"], _get_all

Expand All @@ -42,7 +41,7 @@ def get_one(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_one(_: Request, user: base_models.APIUser, provider_id: str) -> JSONResponse:
client = await self.connected_services_repo.get_oauth2_client(provider_id=provider_id, user=user)
return json(apispec.Provider.model_validate(client).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.Provider, client)

return "/oauth2/providers/<provider_id>", ["GET"], _get_one

Expand All @@ -54,7 +53,7 @@ def post(self) -> BlueprintFactoryResponse:
@validate(json=apispec.ProviderPost)
async def _post(_: Request, user: base_models.APIUser, body: apispec.ProviderPost) -> JSONResponse:
client = await self.connected_services_repo.insert_oauth2_client(user=user, new_client=body)
return json(apispec.Provider.model_validate(client).model_dump(exclude_none=True, mode="json"), 201)
return validated_json(apispec.Provider, client, 201)

return "/oauth2/providers", ["POST"], _post

Expand All @@ -71,7 +70,7 @@ async def _patch(
client = await self.connected_services_repo.update_oauth2_client(
user=user, provider_id=provider_id, **body_dict
)
return json(apispec.Provider.model_validate(client).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.Provider, client)

return "/oauth2/providers/<provider_id>", ["PATCH"], _patch

Expand Down Expand Up @@ -143,9 +142,7 @@ def get_all(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_all(_: Request, user: base_models.APIUser) -> JSONResponse:
connections = await self.connected_services_repo.get_oauth2_connections(user=user)
return json(
[apispec.Connection.model_validate(c).model_dump(exclude_none=True, mode="json") for c in connections]
)
return validated_json(apispec.ConnectionList, connections)

return "/oauth2/connections", ["GET"], _get_all

Expand All @@ -157,7 +154,7 @@ async def _get_one(_: Request, user: base_models.APIUser, connection_id: str) ->
connection = await self.connected_services_repo.get_oauth2_connection(
connection_id=connection_id, user=user
)
return json(apispec.Connection.model_validate(connection).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.Connection, connection)

return "/oauth2/connections/<connection_id>", ["GET"], _get_one

Expand All @@ -169,7 +166,7 @@ async def _get_account(_: Request, user: base_models.APIUser, connection_id: str
account = await self.connected_services_repo.get_oauth2_connected_account(
connection_id=connection_id, user=user
)
return json(apispec.ConnectedAccount.model_validate(account).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.ConnectedAccount, account)

return "/oauth2/connections/<connection_id>/account", ["GET"], _get_account

Expand Down
69 changes: 37 additions & 32 deletions components/renku_data_services/namespace/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass

from sanic import HTTPResponse, Request, json
from sanic import HTTPResponse, Request
from sanic.response import JSONResponse
from sanic_ext import validate

Expand All @@ -12,6 +12,7 @@
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_api.misc import validate_query
from renku_data_services.base_api.pagination import PaginationRequest, paginate
from renku_data_services.base_models.validation import validate_and_dump, validated_json
from renku_data_services.errors import errors
from renku_data_services.namespace import apispec
from renku_data_services.namespace.db import GroupRepository
Expand All @@ -35,7 +36,7 @@ async def _get_all(
) -> tuple[list[dict], int]:
groups, rec_count = await self.group_repo.get_groups(user=user, pagination=pagination)
return (
[apispec.GroupResponse.model_validate(g).model_dump(exclude_none=True, mode="json") for g in groups],
validate_and_dump(apispec.GroupResponseList, groups),
rec_count,
)

Expand All @@ -49,7 +50,7 @@ def post(self) -> BlueprintFactoryResponse:
@validate(json=apispec.GroupPostRequest)
async def _post(_: Request, user: base_models.APIUser, body: apispec.GroupPostRequest) -> JSONResponse:
result = await self.group_repo.insert_group(user=user, payload=body)
return json(apispec.GroupResponse.model_validate(result).model_dump(exclude_none=True, mode="json"), 201)
return validated_json(apispec.GroupResponse, result, 201)

return "/groups", ["POST"], _post

Expand All @@ -59,7 +60,7 @@ def get_one(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_one(_: Request, user: base_models.APIUser, slug: str) -> JSONResponse:
result = await self.group_repo.get_group(user=user, slug=slug)
return json(apispec.GroupResponse.model_validate(result).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.GroupResponse, result)

return "/groups/<slug:renku_slug>", ["GET"], _get_one

Expand All @@ -85,7 +86,7 @@ async def _patch(
) -> JSONResponse:
body_dict = body.model_dump(exclude_none=True)
res = await self.group_repo.update_group(user=user, slug=slug, payload=body_dict)
return json(apispec.GroupResponse.model_validate(res).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.GroupResponse, res)

return "/groups/<slug:renku_slug>", ["PATCH"], _patch

Expand All @@ -95,17 +96,18 @@ def get_all_members(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_all_members(_: Request, user: base_models.APIUser, slug: str) -> JSONResponse:
members = await self.group_repo.get_group_members(user, slug)
return json(
return validated_json(
apispec.GroupMemberResponseList,
[
apispec.GroupMemberResponse(
dict(
id=m.id,
email=m.email,
first_name=m.first_name,
last_name=m.last_name,
role=apispec.GroupRole(m.role.value),
).model_dump(exclude_none=True, mode="json")
)
for m in members
]
],
)

return "/groups/<slug:renku_slug>/members", ["GET"], _get_all_members
Expand All @@ -115,25 +117,24 @@ def update_members(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
@only_authenticated
async def _update_members(
request: Request,
user: base_models.APIUser,
slug: str,
) -> JSONResponse:
async def _update_members(request: Request, user: base_models.APIUser, slug: str) -> JSONResponse:
# TODO: sanic validation does not support validating top-level json lists, switch this to @validate
# once sanic-org/sanic-ext/issues/198 is fixed
body_validated = apispec.GroupMemberPatchRequestList.model_validate(request.json)
res = await self.group_repo.update_group_members(
user=user,
slug=slug,
payload=body_validated,
)
return json(
return validated_json(
apispec.GroupMemberPatchRequestList,
[
apispec.GroupMemberPatchRequest(
dict(
id=m.member.user_id,
role=apispec.GroupRole(m.member.role.value),
).model_dump(exclude_none=True, mode="json")
)
for m in res
]
],
)

return "/groups/<slug:renku_slug>/members", ["PATCH"], _update_members
Expand Down Expand Up @@ -164,17 +165,20 @@ async def _get_namespaces(
nss, total_count = await self.group_repo.get_namespaces(
user=user, pagination=pagination, minimum_role=minimum_role
)
return [
apispec.NamespaceResponse(
id=ns.id,
name=ns.name,
slug=ns.latest_slug if ns.latest_slug else ns.slug,
created_by=ns.created_by,
creation_date=None, # NOTE: we do not save creation date in the DB
namespace_kind=apispec.NamespaceKind(ns.kind.value),
).model_dump(exclude_none=True, mode="json")
for ns in nss
], total_count
return validate_and_dump(
apispec.NamespaceResponseList,
[
dict(
id=ns.id,
name=ns.name,
slug=ns.latest_slug if ns.latest_slug else ns.slug,
created_by=ns.created_by,
creation_date=None, # NOTE: we do not save creation date in the DB
namespace_kind=apispec.NamespaceKind(ns.kind.value),
)
for ns in nss
],
), total_count

return "/namespaces", ["GET"], _get_namespaces

Expand All @@ -186,15 +190,16 @@ async def _get_namespace(_: Request, user: base_models.APIUser, slug: str) -> JS
ns = await self.group_repo.get_namespace_by_slug(user=user, slug=slug)
if not ns:
raise errors.MissingResourceError(message=f"The namespace with slug {slug} does not exist")
return json(
apispec.NamespaceResponse(
return validated_json(
apispec.NamespaceResponse,
dict(
id=ns.id,
name=ns.name,
slug=ns.latest_slug if ns.latest_slug else ns.slug,
created_by=ns.created_by,
creation_date=None, # NOTE: we do not save creation date in the DB
namespace_kind=apispec.NamespaceKind(ns.kind.value),
).model_dump(exclude_none=True, mode="json")
),
)

return "/namespaces/<slug:renku_slug>", ["GET"], _get_namespace
Loading