Skip to content

Commit 87000a6

Browse files
committed
fix: allow session launcher parameters to be reset (#434)
Allows the API to accept None as input for args, command and the session launcher resource class ID so that they can be reset to their defaults in patch endpoints.
1 parent 4b2686f commit 87000a6

File tree

6 files changed

+231
-113
lines changed

6 files changed

+231
-113
lines changed

components/renku_data_services/base_models/core.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass, field
66
from datetime import datetime
77
from enum import Enum, StrEnum
8-
from typing import ClassVar, Optional, Protocol, Self, TypeVar
8+
from typing import ClassVar, NewType, Optional, Protocol, Self, TypeVar
99

1010
from sanic import Request
1111

@@ -212,3 +212,12 @@ class Authenticator(Protocol[AnyAPIUser]):
212212
async def authenticate(self, access_token: str, request: Request) -> AnyAPIUser:
213213
"""Validates the user credentials (i.e. we can say that the user is a valid Renku user)."""
214214
...
215+
216+
217+
ResetType = NewType("ResetType", object)
218+
"""This type represents that a value that may be None should be reset back to None or null.
219+
This type should have only one instance, defined in the same file as this type.
220+
"""
221+
222+
RESET: ResetType = ResetType(object())
223+
"""The single instance of the ResetType, can be compared to similar to None, i.e. `if value is RESET`"""

components/renku_data_services/session/blueprints.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from renku_data_services.base_api.auth import authenticate, only_authenticated
1313
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
1414
from renku_data_services.base_models.validation import validated_json
15-
from renku_data_services.session import apispec, models
15+
from renku_data_services.session import apispec, converters, models
1616
from renku_data_services.session.db import SessionRepository
1717

1818

@@ -76,9 +76,11 @@ def patch(self) -> BlueprintFactoryResponse:
7676
async def _patch(
7777
_: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch
7878
) -> JSONResponse:
79-
body_dict = body.model_dump(exclude_none=True)
79+
update = converters.environment_update_from_patch(body)
8080
environment = await self.session_repo.update_environment(
81-
user=user, environment_id=environment_id, **body_dict
81+
user=user,
82+
environment_id=environment_id,
83+
update=update,
8284
)
8385
return validated_json(apispec.Environment, environment)
8486

@@ -169,34 +171,14 @@ def patch(self) -> BlueprintFactoryResponse:
169171
async def _patch(
170172
_: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch
171173
) -> JSONResponse:
172-
body_dict = body.model_dump(exclude_none=True, mode="json")
173174
async with self.session_repo.session_maker() as session, session.begin():
174175
current_launcher = await self.session_repo.get_launcher(user, launcher_id)
175-
new_env: models.UnsavedEnvironment | None = None
176-
if (
177-
isinstance(body.environment, apispec.EnvironmentPatchInLauncher)
178-
and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL
179-
and body.environment.environment_kind == apispec.EnvironmentKind.CUSTOM
180-
):
181-
# This means that the global environment is being swapped for a custom one,
182-
# so we have to create a brand new environment, but we have to validate here.
183-
validated_env = apispec.EnvironmentPostInLauncher.model_validate(body_dict.pop("environment"))
184-
new_env = models.UnsavedEnvironment(
185-
name=validated_env.name,
186-
description=validated_env.description,
187-
container_image=validated_env.container_image,
188-
default_url=validated_env.default_url,
189-
port=validated_env.port,
190-
working_directory=PurePosixPath(validated_env.working_directory),
191-
mount_directory=PurePosixPath(validated_env.mount_directory),
192-
uid=validated_env.uid,
193-
gid=validated_env.gid,
194-
environment_kind=models.EnvironmentKind(validated_env.environment_kind.value),
195-
args=validated_env.args,
196-
command=validated_env.command,
197-
)
176+
update = converters.launcher_update_from_patch(body, current_launcher)
198177
launcher = await self.session_repo.update_launcher(
199-
user=user, launcher_id=launcher_id, new_custom_environment=new_env, session=session, **body_dict
178+
user=user,
179+
launcher_id=launcher_id,
180+
session=session,
181+
update=update,
200182
)
201183
return validated_json(apispec.SessionLauncher, launcher)
202184

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""Code used to convert from/to apispec and models."""
2+
3+
from pathlib import PurePosixPath
4+
5+
from renku_data_services.base_models.core import RESET, ResetType
6+
from renku_data_services.session import apispec, models
7+
8+
9+
def environment_update_from_patch(data: apispec.EnvironmentPatch) -> models.EnvironmentUpdate:
10+
"""Create an update object from an apispec or any other pydantic model."""
11+
data_dict = data.model_dump(exclude_unset=True, mode="json")
12+
working_directory: PurePosixPath | None = None
13+
if data.working_directory is not None:
14+
working_directory = PurePosixPath(data.working_directory)
15+
mount_directory: PurePosixPath | None = None
16+
if data.mount_directory is not None:
17+
mount_directory = PurePosixPath(data.mount_directory)
18+
# NOTE: If the args or command are present in the data_dict and they are None they were passed in by the user.
19+
# The None specifically passed by the user indicates that the value should be removed from the DB.
20+
args = RESET if "args" in data_dict and data_dict["args"] is None else data.args
21+
command = RESET if "command" in data_dict and data_dict["command"] is None else data.command
22+
return models.EnvironmentUpdate(
23+
name=data.name,
24+
description=data.description,
25+
container_image=data.container_image,
26+
default_url=data.default_url,
27+
port=data.port,
28+
working_directory=working_directory,
29+
mount_directory=mount_directory,
30+
uid=data.uid,
31+
gid=data.gid,
32+
args=args,
33+
command=command,
34+
)
35+
36+
37+
def launcher_update_from_patch(
38+
data: apispec.SessionLauncherPatch,
39+
current_launcher: models.SessionLauncher | None = None,
40+
) -> models.SessionLauncherUpdate:
41+
"""Create an update object from an apispec or any other pydantic model."""
42+
data_dict = data.model_dump(exclude_unset=True, mode="json")
43+
environment: str | models.EnvironmentUpdate | models.UnsavedEnvironment | None = None
44+
if (
45+
isinstance(data.environment, apispec.EnvironmentPatchInLauncher)
46+
and current_launcher is not None
47+
and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL
48+
and data.environment.environment_kind == apispec.EnvironmentKind.CUSTOM
49+
):
50+
# This means that the global environment is being swapped for a custom one,
51+
# so we have to create a brand new environment, but we have to validate here.
52+
validated_env = apispec.EnvironmentPostInLauncher.model_validate(data_dict["environment"])
53+
environment = models.UnsavedEnvironment(
54+
name=validated_env.name,
55+
description=validated_env.description,
56+
container_image=validated_env.container_image,
57+
default_url=validated_env.default_url,
58+
port=validated_env.port,
59+
working_directory=PurePosixPath(validated_env.working_directory),
60+
mount_directory=PurePosixPath(validated_env.mount_directory),
61+
uid=validated_env.uid,
62+
gid=validated_env.gid,
63+
environment_kind=models.EnvironmentKind(validated_env.environment_kind.value),
64+
args=validated_env.args,
65+
command=validated_env.command,
66+
)
67+
elif isinstance(data.environment, apispec.EnvironmentPatchInLauncher):
68+
environment = environment_update_from_patch(data.environment)
69+
elif isinstance(data.environment, apispec.EnvironmentIdOnlyPatch):
70+
environment = data.environment.id
71+
resource_class_id: int | None | ResetType = None
72+
if "resource_class_id" in data_dict and data_dict["resource_class_id"] is None:
73+
# NOTE: This means that the resource class set in the DB should be removed so that the
74+
# default resource class currently set in the CRC will be used.
75+
resource_class_id = RESET
76+
else:
77+
resource_class_id = data_dict.get("resource_class_id")
78+
return models.SessionLauncherUpdate(
79+
name=data_dict.get("name"),
80+
description=data_dict.get("description"),
81+
environment=environment,
82+
resource_class_id=resource_class_id,
83+
)

components/renku_data_services/session/db.py

Lines changed: 69 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from collections.abc import Callable
66
from contextlib import AbstractAsyncContextManager, nullcontext
7-
from typing import Any
87

98
from sqlalchemy import select
109
from sqlalchemy.ext.asyncio import AsyncSession
@@ -14,6 +13,7 @@
1413
from renku_data_services import errors
1514
from renku_data_services.authz.authz import Authz, ResourceType
1615
from renku_data_services.authz.models import Scope
16+
from renku_data_services.base_models.core import RESET
1717
from renku_data_services.crc.db import ResourcePoolRepository
1818
from renku_data_services.session import models
1919
from renku_data_services.session import orm as schemas
@@ -101,53 +101,59 @@ async def insert_environment(
101101
await session.refresh(env)
102102
return env.dump()
103103

104-
async def __update_environment(
104+
def __update_environment(
105105
self,
106-
user: base_models.APIUser,
107-
session: AsyncSession,
108-
environment_id: ULID,
109-
kind: models.EnvironmentKind,
110-
**kwargs: dict,
111-
) -> models.Environment:
112-
res = await session.scalars(
113-
select(schemas.EnvironmentORM)
114-
.where(schemas.EnvironmentORM.id == str(environment_id))
115-
.where(schemas.EnvironmentORM.environment_kind == kind.value)
116-
)
117-
environment = res.one_or_none()
118-
if environment is None:
119-
raise errors.MissingResourceError(message=f"Session environment with id '{environment_id}' does not exist.")
120-
121-
for key, value in kwargs.items():
122-
# NOTE: Only some fields can be edited
123-
if key in [
124-
"name",
125-
"description",
126-
"container_image",
127-
"default_url",
128-
"port",
129-
"working_directory",
130-
"mount_directory",
131-
"uid",
132-
"gid",
133-
"args",
134-
"command",
135-
]:
136-
setattr(environment, key, value)
137-
138-
return environment.dump()
106+
environment: schemas.EnvironmentORM,
107+
update: models.EnvironmentUpdate,
108+
) -> None:
109+
# NOTE: this is more verbose than a loop and setattr but this way we get mypy type checks
110+
if update.name is not None:
111+
environment.name = update.name
112+
if update.description is not None:
113+
environment.description = update.description
114+
if update.container_image is not None:
115+
environment.container_image = update.container_image
116+
if update.default_url is not None:
117+
environment.default_url = update.default_url
118+
if update.port is not None:
119+
environment.port = update.port
120+
if update.working_directory is not None:
121+
environment.working_directory = update.working_directory
122+
if update.mount_directory is not None:
123+
environment.mount_directory = update.mount_directory
124+
if update.uid is not None:
125+
environment.uid = update.uid
126+
if update.gid is not None:
127+
environment.gid = update.gid
128+
if update.args is RESET:
129+
environment.args = None
130+
elif isinstance(update.args, list):
131+
environment.args = update.args
132+
if update.command is RESET:
133+
environment.command = None
134+
elif isinstance(update.command, list):
135+
environment.command = update.command
139136

140137
async def update_environment(
141-
self, user: base_models.APIUser, environment_id: ULID, **kwargs: dict
138+
self, user: base_models.APIUser, environment_id: ULID, update: models.EnvironmentUpdate
142139
) -> models.Environment:
143140
"""Update a global session environment entry."""
144141
if not user.is_admin:
145142
raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.")
146143

147144
async with self.session_maker() as session, session.begin():
148-
return await self.__update_environment(
149-
user, session, environment_id, models.EnvironmentKind.GLOBAL, **kwargs
145+
res = await session.scalars(
146+
select(schemas.EnvironmentORM)
147+
.where(schemas.EnvironmentORM.id == str(environment_id))
148+
.where(schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL)
150149
)
150+
environment = res.one_or_none()
151+
if environment is None:
152+
raise errors.MissingResourceError(
153+
message=f"Session environment with id '{environment_id}' does not exist."
154+
)
155+
self.__update_environment(environment, update)
156+
return environment.dump()
151157

152158
async def delete_environment(self, user: base_models.APIUser, environment_id: ULID) -> None:
153159
"""Delete a global session environment entry."""
@@ -297,9 +303,8 @@ async def update_launcher(
297303
self,
298304
user: base_models.APIUser,
299305
launcher_id: ULID,
300-
new_custom_environment: models.UnsavedEnvironment | None,
306+
update: models.SessionLauncherUpdate,
301307
session: AsyncSession | None = None,
302-
**kwargs: Any,
303308
) -> models.SessionLauncher:
304309
"""Update a session launcher entry."""
305310
if not user.is_authenticated or user.id is None:
@@ -333,8 +338,8 @@ async def update_launcher(
333338
if not authorized:
334339
raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")
335340

336-
resource_class_id = kwargs.get("resource_class_id")
337-
if resource_class_id is not None:
341+
resource_class_id = update.resource_class_id
342+
if isinstance(resource_class_id, int):
338343
res = await session.scalars(
339344
select(schemas.ResourceClassORM).where(schemas.ResourceClassORM.id == resource_class_id)
340345
)
@@ -351,32 +356,32 @@ async def update_launcher(
351356
message=f"You do not have access to resource class with id '{resource_class_id}'."
352357
)
353358

354-
for key, value in kwargs.items():
355-
# NOTE: Only some fields can be updated.
356-
if key in [
357-
"name",
358-
"description",
359-
"resource_class_id",
360-
]:
361-
setattr(launcher, key, value)
362-
363-
env_payload = kwargs.get("environment", {})
364-
await self.__update_launcher_environment(user, launcher, session, new_custom_environment, **env_payload)
365-
await session.flush()
366-
await session.refresh(launcher)
359+
# NOTE: Only some fields can be updated.
360+
if update.name is not None:
361+
launcher.name = update.name
362+
if update.description is not None:
363+
launcher.description = update.description
364+
if isinstance(update.resource_class_id, int):
365+
launcher.resource_class_id = update.resource_class_id
366+
elif update.resource_class_id is RESET:
367+
launcher.resource_class_id = None
368+
369+
if update.environment is None:
370+
return launcher.dump()
371+
372+
await self.__update_launcher_environment(user, launcher, session, update.environment)
367373
return launcher.dump()
368374

369375
async def __update_launcher_environment(
370376
self,
371377
user: base_models.APIUser,
372378
launcher: schemas.SessionLauncherORM,
373379
session: AsyncSession,
374-
new_custom_environment: models.UnsavedEnvironment | None,
375-
**kwargs: Any,
380+
update: models.EnvironmentUpdate | models.UnsavedEnvironment | str,
376381
) -> None:
377382
current_env_kind = launcher.environment.environment_kind
378-
match new_custom_environment, current_env_kind, kwargs:
379-
case None, _, {"id": env_id, **nothing_else} if len(nothing_else) == 0:
383+
match update, current_env_kind:
384+
case str() as env_id, _:
380385
# The environment in the launcher is set via ID, the new ID has to refer
381386
# to an environment that is GLOBAL.
382387
old_environment = launcher.environment
@@ -403,33 +408,16 @@ async def __update_launcher_environment(
403408
# We remove the custom environment to avoid accumulating custom environments that are not associated
404409
# with any launchers.
405410
await session.delete(old_environment)
406-
case None, models.EnvironmentKind.CUSTOM, {**rest} if (
407-
rest.get("environment_kind") is None
408-
or rest.get("environment_kind") == models.EnvironmentKind.CUSTOM.value
409-
):
411+
case models.EnvironmentUpdate(), models.EnvironmentKind.CUSTOM:
410412
# Custom environment being updated
411-
for key, val in rest.items():
412-
# NOTE: Only some fields can be updated.
413-
if key in [
414-
"name",
415-
"description",
416-
"container_image",
417-
"default_url",
418-
"port",
419-
"working_directory",
420-
"mount_directory",
421-
"uid",
422-
"gid",
423-
"args",
424-
"command",
425-
]:
426-
setattr(launcher.environment, key, val)
427-
case models.UnsavedEnvironment(), models.EnvironmentKind.GLOBAL, {**nothing_else} if (
428-
len(nothing_else) == 0 and new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM
413+
self.__update_environment(launcher.environment, update)
414+
case models.UnsavedEnvironment() as new_custom_environment, models.EnvironmentKind.GLOBAL if (
415+
new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM
429416
):
430417
# Global environment replaced by a custom one
431418
new_env = await self.__insert_environment(user, session, new_custom_environment)
432419
launcher.environment = new_env
420+
await session.flush()
433421
case _:
434422
raise errors.ValidationError(
435423
message="Encountered an invalid payload for updating a launcher environment", quiet=True

0 commit comments

Comments
 (0)