44
55from collections .abc import Callable
66from contextlib import AbstractAsyncContextManager , nullcontext
7- from typing import Any
87
98from sqlalchemy import select
109from sqlalchemy .ext .asyncio import AsyncSession
1413from renku_data_services import errors
1514from renku_data_services .authz .authz import Authz , ResourceType
1615from renku_data_services .authz .models import Scope
16+ from renku_data_services .base_models .core import RESET
1717from renku_data_services .crc .db import ResourcePoolRepository
1818from renku_data_services .session import models
1919from renku_data_services .session import orm as schemas
@@ -101,53 +101,59 @@ async def insert_environment(
101101 await session .refresh (env )
102102 return env .dump ()
103103
104- async def __update_environment (
104+ def __update_environment (
105105 self ,
106- user : base_models .APIUser ,
107- session : AsyncSession ,
108- environment_id : ULID ,
109- kind : models .EnvironmentKind ,
110- ** kwargs : dict ,
111- ) -> models .Environment :
112- res = await session .scalars (
113- select (schemas .EnvironmentORM )
114- .where (schemas .EnvironmentORM .id == str (environment_id ))
115- .where (schemas .EnvironmentORM .environment_kind == kind .value )
116- )
117- environment = res .one_or_none ()
118- if environment is None :
119- raise errors .MissingResourceError (message = f"Session environment with id '{ environment_id } ' does not exist." )
120-
121- for key , value in kwargs .items ():
122- # NOTE: Only some fields can be edited
123- if key in [
124- "name" ,
125- "description" ,
126- "container_image" ,
127- "default_url" ,
128- "port" ,
129- "working_directory" ,
130- "mount_directory" ,
131- "uid" ,
132- "gid" ,
133- "args" ,
134- "command" ,
135- ]:
136- setattr (environment , key , value )
137-
138- return environment .dump ()
106+ environment : schemas .EnvironmentORM ,
107+ update : models .EnvironmentUpdate ,
108+ ) -> None :
109+ # NOTE: this is more verbose than a loop and setattr but this way we get mypy type checks
110+ if update .name is not None :
111+ environment .name = update .name
112+ if update .description is not None :
113+ environment .description = update .description
114+ if update .container_image is not None :
115+ environment .container_image = update .container_image
116+ if update .default_url is not None :
117+ environment .default_url = update .default_url
118+ if update .port is not None :
119+ environment .port = update .port
120+ if update .working_directory is not None :
121+ environment .working_directory = update .working_directory
122+ if update .mount_directory is not None :
123+ environment .mount_directory = update .mount_directory
124+ if update .uid is not None :
125+ environment .uid = update .uid
126+ if update .gid is not None :
127+ environment .gid = update .gid
128+ if update .args is RESET :
129+ environment .args = None
130+ elif isinstance (update .args , list ):
131+ environment .args = update .args
132+ if update .command is RESET :
133+ environment .command = None
134+ elif isinstance (update .command , list ):
135+ environment .command = update .command
139136
140137 async def update_environment (
141- self , user : base_models .APIUser , environment_id : ULID , ** kwargs : dict
138+ self , user : base_models .APIUser , environment_id : ULID , update : models . EnvironmentUpdate
142139 ) -> models .Environment :
143140 """Update a global session environment entry."""
144141 if not user .is_admin :
145142 raise errors .UnauthorizedError (message = "You do not have the required permissions for this operation." )
146143
147144 async with self .session_maker () as session , session .begin ():
148- return await self .__update_environment (
149- user , session , environment_id , models .EnvironmentKind .GLOBAL , ** kwargs
145+ res = await session .scalars (
146+ select (schemas .EnvironmentORM )
147+ .where (schemas .EnvironmentORM .id == str (environment_id ))
148+ .where (schemas .EnvironmentORM .environment_kind == models .EnvironmentKind .GLOBAL )
150149 )
150+ environment = res .one_or_none ()
151+ if environment is None :
152+ raise errors .MissingResourceError (
153+ message = f"Session environment with id '{ environment_id } ' does not exist."
154+ )
155+ self .__update_environment (environment , update )
156+ return environment .dump ()
151157
152158 async def delete_environment (self , user : base_models .APIUser , environment_id : ULID ) -> None :
153159 """Delete a global session environment entry."""
@@ -297,9 +303,8 @@ async def update_launcher(
297303 self ,
298304 user : base_models .APIUser ,
299305 launcher_id : ULID ,
300- new_custom_environment : models .UnsavedEnvironment | None ,
306+ update : models .SessionLauncherUpdate ,
301307 session : AsyncSession | None = None ,
302- ** kwargs : Any ,
303308 ) -> models .SessionLauncher :
304309 """Update a session launcher entry."""
305310 if not user .is_authenticated or user .id is None :
@@ -333,8 +338,8 @@ async def update_launcher(
333338 if not authorized :
334339 raise errors .ForbiddenError (message = "You do not have the required permissions for this operation." )
335340
336- resource_class_id = kwargs . get ( " resource_class_id" )
337- if resource_class_id is not None :
341+ resource_class_id = update . resource_class_id
342+ if isinstance ( resource_class_id , int ) :
338343 res = await session .scalars (
339344 select (schemas .ResourceClassORM ).where (schemas .ResourceClassORM .id == resource_class_id )
340345 )
@@ -351,32 +356,32 @@ async def update_launcher(
351356 message = f"You do not have access to resource class with id '{ resource_class_id } '."
352357 )
353358
354- for key , value in kwargs .items ():
355- # NOTE: Only some fields can be updated.
356- if key in [
357- "name" ,
358- "description" ,
359- "resource_class_id" ,
360- ]:
361- setattr (launcher , key , value )
362-
363- env_payload = kwargs .get ("environment" , {})
364- await self .__update_launcher_environment (user , launcher , session , new_custom_environment , ** env_payload )
365- await session .flush ()
366- await session .refresh (launcher )
359+ # NOTE: Only some fields can be updated.
360+ if update .name is not None :
361+ launcher .name = update .name
362+ if update .description is not None :
363+ launcher .description = update .description
364+ if isinstance (update .resource_class_id , int ):
365+ launcher .resource_class_id = update .resource_class_id
366+ elif update .resource_class_id is RESET :
367+ launcher .resource_class_id = None
368+
369+ if update .environment is None :
370+ return launcher .dump ()
371+
372+ await self .__update_launcher_environment (user , launcher , session , update .environment )
367373 return launcher .dump ()
368374
369375 async def __update_launcher_environment (
370376 self ,
371377 user : base_models .APIUser ,
372378 launcher : schemas .SessionLauncherORM ,
373379 session : AsyncSession ,
374- new_custom_environment : models .UnsavedEnvironment | None ,
375- ** kwargs : Any ,
380+ update : models .EnvironmentUpdate | models .UnsavedEnvironment | str ,
376381 ) -> None :
377382 current_env_kind = launcher .environment .environment_kind
378- match new_custom_environment , current_env_kind , kwargs :
379- case None , _, { "id" : env_id , ** nothing_else } if len ( nothing_else ) == 0 :
383+ match update , current_env_kind :
384+ case str () as env_id , _ :
380385 # The environment in the launcher is set via ID, the new ID has to refer
381386 # to an environment that is GLOBAL.
382387 old_environment = launcher .environment
@@ -403,33 +408,16 @@ async def __update_launcher_environment(
403408 # We remove the custom environment to avoid accumulating custom environments that are not associated
404409 # with any launchers.
405410 await session .delete (old_environment )
406- case None , models .EnvironmentKind .CUSTOM , {** rest } if (
407- rest .get ("environment_kind" ) is None
408- or rest .get ("environment_kind" ) == models .EnvironmentKind .CUSTOM .value
409- ):
411+ case models .EnvironmentUpdate (), models .EnvironmentKind .CUSTOM :
410412 # Custom environment being updated
411- for key , val in rest .items ():
412- # NOTE: Only some fields can be updated.
413- if key in [
414- "name" ,
415- "description" ,
416- "container_image" ,
417- "default_url" ,
418- "port" ,
419- "working_directory" ,
420- "mount_directory" ,
421- "uid" ,
422- "gid" ,
423- "args" ,
424- "command" ,
425- ]:
426- setattr (launcher .environment , key , val )
427- case models .UnsavedEnvironment (), models .EnvironmentKind .GLOBAL , {** nothing_else } if (
428- len (nothing_else ) == 0 and new_custom_environment .environment_kind == models .EnvironmentKind .CUSTOM
413+ self .__update_environment (launcher .environment , update )
414+ case models .UnsavedEnvironment () as new_custom_environment , models .EnvironmentKind .GLOBAL if (
415+ new_custom_environment .environment_kind == models .EnvironmentKind .CUSTOM
429416 ):
430417 # Global environment replaced by a custom one
431418 new_env = await self .__insert_environment (user , session , new_custom_environment )
432419 launcher .environment = new_env
420+ await session .flush ()
433421 case _:
434422 raise errors .ValidationError (
435423 message = "Encountered an invalid payload for updating a launcher environment" , quiet = True
0 commit comments