Skip to content

Commit f4b6480

Browse files
feat(cli): session pause and resume (#3633)
1 parent b31ade0 commit f4b6480

File tree

13 files changed

+378
-35
lines changed

13 files changed

+378
-35
lines changed

docs/_static/cheatsheet/cheatsheet.json

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,20 @@
325325
"rp"
326326
]
327327
},
328+
{
329+
"command": "$ renku session pause <name>",
330+
"description": "Pause the specified session.",
331+
"target": [
332+
"rp"
333+
]
334+
},
335+
{
336+
"command": "$ renku session resume <name>",
337+
"description": "Resume the specified paused session.",
338+
"target": [
339+
"rp"
340+
]
341+
},
328342
{
329343
"command": "$ renku session stop <name>",
330344
"description": "Stop the specified session.",
97 Bytes
Binary file not shown.

docs/cheatsheet_hash

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
ad86ac1d0614ccb692c96e893db4d20d cheatsheet.tex
1+
5316163d742bdb6792ed8bcb35031f6c cheatsheet.tex
22
c70c179e07f04186ec05497564165f11 sdsc_cheatsheet.cls

docs/cheatsheet_json_hash

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1ac51267cefdf4976c29c9d7657063b8 cheatsheet.json
1+
1856fb451165d013777c7c4cdd56e575 cheatsheet.json

renku/command/session.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717

1818
from renku.command.command_builder.command import Command
1919
from renku.core.session.session import (
20+
search_hibernating_session_providers,
2021
search_session_providers,
2122
search_sessions,
2223
session_list,
2324
session_open,
25+
session_pause,
26+
session_resume,
2427
session_start,
2528
session_stop,
2629
ssh_setup,
@@ -37,6 +40,11 @@ def search_session_providers_command():
3740
return Command().command(search_session_providers).require_migration().with_database(write=False)
3841

3942

43+
def search_hibernating_session_providers_command():
44+
"""Get all the session provider names that support hibernation and match a pattern."""
45+
return Command().command(search_hibernating_session_providers).require_migration().with_database(write=False)
46+
47+
4048
def session_list_command():
4149
"""List all the running interactive sessions."""
4250
return Command().command(session_list).with_database(write=False)
@@ -49,14 +57,24 @@ def session_start_command():
4957

5058
def session_stop_command():
5159
"""Stop a running an interactive session."""
52-
return Command().command(session_stop)
60+
return Command().command(session_stop).with_database(write=False)
5361

5462

5563
def session_open_command():
5664
"""Open a running interactive session."""
57-
return Command().command(session_open)
65+
return Command().command(session_open).with_database(write=False)
5866

5967

6068
def ssh_setup_command():
6169
"""Setup SSH keys for SSH connections to sessions."""
6270
return Command().command(ssh_setup)
71+
72+
73+
def session_pause_command():
74+
"""Pause a running interactive session."""
75+
return Command().command(session_pause).with_database(write=False)
76+
77+
78+
def session_resume_command():
79+
"""Resume a paused session."""
80+
return Command().command(session_resume).with_database(write=False)

renku/core/plugin/session.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import pluggy
2020

21-
from renku.domain_model.session import ISessionProvider
21+
from renku.domain_model.session import IHibernatingSessionProvider, ISessionProvider
2222

2323
hookspec = pluggy.HookspecMarker("renku")
2424

@@ -41,3 +41,9 @@ def get_supported_session_providers() -> List[ISessionProvider]:
4141
providers = pm.hook.session_provider()
4242

4343
return sorted(providers, key=lambda p: p.priority)
44+
45+
46+
def get_supported_hibernating_session_providers() -> List[IHibernatingSessionProvider]:
47+
"""Returns the currently available interactive session providers that support hibernation."""
48+
providers = get_supported_session_providers()
49+
return [p for p in providers if isinstance(p, IHibernatingSessionProvider)]

renku/core/session/renkulab.py

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@
3434
from renku.core.util.jwt import is_token_expired
3535
from renku.core.util.ssh import SystemSSHConfig
3636
from renku.domain_model.project_context import project_context
37-
from renku.domain_model.session import ISessionProvider, Session, SessionStopStatus
37+
from renku.domain_model.session import IHibernatingSessionProvider, Session, SessionStopStatus
3838

3939
if TYPE_CHECKING:
4040
from renku.core.dataset.providers.models import ProviderParameter
4141

4242

43-
class RenkulabSessionProvider(ISessionProvider):
43+
class RenkulabSessionProvider(IHibernatingSessionProvider):
4444
"""A session provider that uses the notebook service API to launch sessions."""
4545

4646
DEFAULT_TIMEOUT_SECONDS = 300
@@ -118,7 +118,7 @@ def _wait_for_session_status(
118118
)
119119
if res.status_code == 404 and status == "stopping":
120120
return
121-
if res.status_code == 200 and status != "stopping":
121+
if res.status_code in [200, 204] and status != "stopping":
122122
if res.json().get("status", {}).get("state") == status:
123123
return
124124
sleep(5)
@@ -210,9 +210,9 @@ def _remote_head_hexsha():
210210

211211
return remote.head
212212

213-
def _send_renku_request(self, req_type: str, *args, **kwargs):
214-
res = getattr(requests, req_type)(*args, **kwargs)
215-
if res.status_code == 401:
213+
def _send_renku_request(self, verb: str, *args, **kwargs):
214+
response = getattr(requests, verb)(*args, **kwargs)
215+
if response.status_code == 401:
216216
# NOTE: Check if logged in to KC but not the Renku UI
217217
token = read_renku_token(endpoint=self._renku_url())
218218
if token and not is_token_expired(token):
@@ -222,7 +222,7 @@ def _send_renku_request(self, req_type: str, *args, **kwargs):
222222
raise errors.AuthenticationError(
223223
"Please run the renku login command to authenticate with Renku or to refresh your expired credentials."
224224
)
225-
return res
225+
return response
226226

227227
@staticmethod
228228
def _project_name_from_full_project_name(project_name: str) -> str:
@@ -262,7 +262,7 @@ def find_image(self, image_name: str, config: Optional[Dict[str, Any]]) -> bool:
262262
)
263263

264264
@hookimpl
265-
def session_provider(self) -> ISessionProvider:
265+
def session_provider(self) -> IHibernatingSessionProvider:
266266
"""Supported session provider.
267267
268268
Returns:
@@ -511,3 +511,69 @@ def session_url(self, session_name: str) -> str:
511511
def force_build_image(self, **kwargs) -> bool:
512512
"""Whether we should force build the image directly or check for an existing image first."""
513513
return self._force_build
514+
515+
def session_pause(self, project_name: str, session_name: Optional[str], **_) -> SessionStopStatus:
516+
"""Pause all sessions (for the given project) or a specific interactive session."""
517+
518+
def pause(session_name: str):
519+
result = self._send_renku_request(
520+
"patch",
521+
f"{self._notebooks_url()}/servers/{session_name}",
522+
headers=self._auth_header(),
523+
json={"state": "hibernated"},
524+
)
525+
526+
self._wait_for_session_status(session_name, "hibernated")
527+
528+
return result
529+
530+
sessions = self.session_list(project_name=project_name)
531+
n_sessions = len(sessions)
532+
533+
if n_sessions == 0:
534+
return SessionStopStatus.NO_ACTIVE_SESSION
535+
536+
if session_name:
537+
response = pause(session_name)
538+
elif n_sessions == 1:
539+
response = pause(sessions[0].name)
540+
else:
541+
return SessionStopStatus.NAME_NEEDED
542+
543+
return SessionStopStatus.SUCCESSFUL if response.status_code == 204 else SessionStopStatus.FAILED
544+
545+
def session_resume(self, project_name: str, session_name: Optional[str], **kwargs) -> bool:
546+
"""Resume a paused session.
547+
548+
Args:
549+
project_name(str): Renku project name.
550+
session_name(Optional[str]): The unique id of the interactive session.
551+
"""
552+
sessions = self.session_list(project_name="")
553+
system_config = SystemSSHConfig()
554+
name = self._project_name_from_full_project_name(project_name)
555+
ssh_prefix = f"{system_config.renku_host}-{name}-"
556+
557+
if not session_name:
558+
if len(sessions) == 1:
559+
session_name = sessions[0].name
560+
else:
561+
return False
562+
else:
563+
if session_name.startswith(ssh_prefix):
564+
# NOTE: User passed in ssh connection name instead of session id by accident
565+
session_name = session_name.replace(ssh_prefix, "", 1)
566+
567+
if not any(s.name == session_name for s in sessions):
568+
return False
569+
570+
self._send_renku_request(
571+
"patch",
572+
f"{self._notebooks_url()}/servers/{session_name}",
573+
headers=self._auth_header(),
574+
json={"state": "running"},
575+
)
576+
577+
self._wait_for_session_status(session_name, "running")
578+
579+
return True

renku/core/session/session.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525

2626
from renku.core import errors
2727
from renku.core.config import get_value
28-
from renku.core.plugin.session import get_supported_session_providers
28+
from renku.core.plugin.session import get_supported_hibernating_session_providers, get_supported_session_providers
2929
from renku.core.session.utils import get_image_repository_host, get_renku_project_name
3030
from renku.core.util import communication
3131
from renku.core.util.os import safe_read_yaml
3232
from renku.core.util.ssh import SystemSSHConfig, generate_ssh_keys
33-
from renku.domain_model.session import ISessionProvider, Session, SessionStopStatus
33+
from renku.domain_model.session import IHibernatingSessionProvider, ISessionProvider, Session, SessionStopStatus
3434

3535

3636
def _safe_get_provider(provider: str) -> ISessionProvider:
@@ -80,6 +80,22 @@ def search_session_providers(name: str) -> List[str]:
8080
return [p.name for p in get_supported_session_providers() if p.name.lower().startswith(name)]
8181

8282

83+
@validate_arguments(config=dict(arbitrary_types_allowed=True))
84+
def search_hibernating_session_providers(name: str) -> List[str]:
85+
"""Get all session providers that support hibernation and their name starts with the given name.
86+
87+
Args:
88+
name(str): The name to search for.
89+
90+
Returns:
91+
All session providers whose name starts with ``name``.
92+
"""
93+
from renku.core.plugin.session import get_supported_hibernating_session_providers
94+
95+
name = name.lower()
96+
return [p.name for p in get_supported_hibernating_session_providers() if p.name.lower().startswith(name)]
97+
98+
8399
@validate_arguments(config=dict(arbitrary_types_allowed=True))
84100
def session_list(*, provider: Optional[str] = None) -> SessionList:
85101
"""List interactive sessions.
@@ -358,3 +374,94 @@ def ssh_setup(existing_key: Optional[Path] = None, force: bool = False):
358374
"This command does not add any public SSH keys to your project. "
359375
"Keys have to be added manually or by using the 'renku session start' command with the '--ssh' flag."
360376
)
377+
378+
379+
@validate_arguments(config=dict(arbitrary_types_allowed=True))
380+
def session_pause(session_name: Optional[str], provider: Optional[str] = None, **kwargs):
381+
"""Pause an interactive session.
382+
383+
Args:
384+
session_name(Optional[str]): Name of the session.
385+
provider(Optional[str]): Name of the session provider to use.
386+
"""
387+
388+
def pause(session_provider: IHibernatingSessionProvider) -> SessionStopStatus:
389+
try:
390+
return session_provider.session_pause(project_name=project_name, session_name=session_name)
391+
except errors.RenkulabSessionGetUrlError:
392+
if provider:
393+
raise
394+
return SessionStopStatus.FAILED
395+
396+
project_name = get_renku_project_name()
397+
398+
if provider:
399+
session_provider = _safe_get_provider(provider)
400+
if session_provider is None:
401+
raise errors.ParameterError(f"Provider '{provider}' not found")
402+
elif not isinstance(session_provider, IHibernatingSessionProvider):
403+
raise errors.ParameterError(f"Provider '{provider}' doesn't support pausing sessions")
404+
providers = [session_provider]
405+
else:
406+
providers = get_supported_hibernating_session_providers()
407+
408+
session_message = f"session {session_name}" if session_name else "session"
409+
statues = []
410+
warning_messages = []
411+
with communication.busy(msg=f"Waiting for {session_message} to pause..."):
412+
for session_provider in sorted(providers, key=lambda p: p.priority):
413+
try:
414+
status = pause(session_provider) # type: ignore
415+
except errors.RenkuException as e:
416+
warning_messages.append(f"Cannot pause sessions in provider '{session_provider.name}': {e}")
417+
else:
418+
statues.append(status)
419+
420+
# NOTE: The given session name was stopped; don't continue
421+
if session_name and status == SessionStopStatus.SUCCESSFUL:
422+
break
423+
424+
if warning_messages:
425+
for message in warning_messages:
426+
communication.warn(message)
427+
428+
if not statues:
429+
return
430+
elif all(s == SessionStopStatus.NO_ACTIVE_SESSION for s in statues):
431+
raise errors.ParameterError("There are no running sessions.")
432+
elif session_name and not any(s == SessionStopStatus.SUCCESSFUL for s in statues):
433+
raise errors.ParameterError(f"Could not find '{session_name}' among the running sessions.")
434+
elif not session_name and not any(s == SessionStopStatus.SUCCESSFUL for s in statues):
435+
raise errors.ParameterError("Session name is missing")
436+
437+
438+
@validate_arguments(config=dict(arbitrary_types_allowed=True))
439+
def session_resume(session_name: Optional[str], provider: Optional[str] = None, **kwargs):
440+
"""Resume a paused session.
441+
442+
Args:
443+
session_name(Optional[str]): Name of the session.
444+
provider(Optional[str]): Name of the session provider to use.
445+
"""
446+
project_name = get_renku_project_name()
447+
448+
if provider:
449+
session_provider = _safe_get_provider(provider)
450+
if session_provider is None:
451+
raise errors.ParameterError(f"Provider '{provider}' not found")
452+
elif not isinstance(session_provider, IHibernatingSessionProvider):
453+
raise errors.ParameterError(f"Provider '{provider}' doesn't support pausing/resuming sessions")
454+
providers = [session_provider]
455+
else:
456+
providers = get_supported_hibernating_session_providers()
457+
458+
session_message = f"session {session_name}" if session_name else "session"
459+
with communication.busy(msg=f"Waiting for {session_message} to resume..."):
460+
for session_provider in providers:
461+
if session_provider.session_resume(project_name, session_name, **kwargs): # type: ignore
462+
return
463+
464+
if session_name:
465+
raise errors.ParameterError(f"Could not find '{session_name}' among the sessions.")
466+
else:
467+
raise errors.ParameterError("Session name is missing")

renku/core/util/requests.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ def put(url, *, data=None, files=None, headers=None, params=None):
7878
return _request("put", url=url, data=data, files=files, headers=headers, params=params)
7979

8080

81+
def patch(url, *, json=None, files=None, headers=None, params=None):
82+
"""Send a PATCH request."""
83+
return _request("patch", url=url, json=json, files=files, headers=headers, params=params)
84+
85+
8186
def _request(verb: str, url: str, *, allow_redirects=True, data=None, files=None, headers=None, json=None, params=None):
8287
try:
8388
with _retry() as session:

0 commit comments

Comments
 (0)