Skip to content

Commit d03963a

Browse files
committed
Fix
Signed-off-by: Hemil Desai <[email protected]>
1 parent 9f852c3 commit d03963a

File tree

4 files changed

+36
-52
lines changed

4 files changed

+36
-52
lines changed

src/nemo_run/core/execution/slurm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,14 @@
4242
from nemo_run.core.packaging.base import Packager
4343
from nemo_run.core.packaging.git import GitArchivePackager
4444
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
45-
from nemo_run.core.tunnel.callback import Callback
46-
from nemo_run.core.tunnel.client import LocalTunnel, PackagingJob, SSHConfigFile, SSHTunnel, Tunnel
45+
from nemo_run.core.tunnel.client import (
46+
Callback,
47+
LocalTunnel,
48+
PackagingJob,
49+
SSHConfigFile,
50+
SSHTunnel,
51+
Tunnel,
52+
)
4753
from nemo_run.core.tunnel.server import TunnelMetadata, server_dir
4854
from nemo_run.devspace.base import DevSpace
4955

src/nemo_run/core/tunnel/callback.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

src/nemo_run/core/tunnel/client.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from abc import ABC, abstractmethod
2525
from dataclasses import dataclass, field
2626
from pathlib import Path
27-
from typing import TYPE_CHECKING, Callable, Optional
27+
from typing import Callable, Optional
2828

2929
import paramiko
3030
import paramiko.ssh_exception
@@ -35,9 +35,6 @@
3535
from nemo_run.config import NEMORUN_HOME, ConfigurableMixin
3636
from nemo_run.core.frontend.console.api import CONSOLE
3737

38-
if TYPE_CHECKING:
39-
from nemo_run.core.tunnel.callback import Callback
40-
4138
logger: logging.Logger = logging.getLogger(__name__)
4239
TUNNEL_DIR = ".tunnels"
4340
TUNNEL_FILE_SUBPATH = os.path.join(NEMORUN_HOME, TUNNEL_DIR)
@@ -383,3 +380,29 @@ def remove_entry(self, name: str):
383380
file.writelines(lines)
384381

385382
print(f"Removed SSH config entry for {host}.")
383+
384+
385+
class Callback:
386+
def setup(self, tunnel: "Tunnel"):
387+
"""Called when the tunnel is setup."""
388+
self.tunnel = tunnel
389+
390+
def on_start(self):
391+
"""Called when the keep_alive loop starts."""
392+
pass
393+
394+
def on_interval(self):
395+
"""Called at each interval during the keep_alive loop."""
396+
pass
397+
398+
def on_stop(self):
399+
"""Called when the keep_alive loop stops."""
400+
pass
401+
402+
def on_error(self, error: Exception):
403+
"""Called when an error occurs during the keep_alive loop.
404+
405+
Args:
406+
error (Exception): The exception that was raised.
407+
"""
408+
pass

src/nemo_run/devspace/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import fiddle as fdl
2020

2121
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
22-
from nemo_run.core.tunnel.callback import Callback
22+
from nemo_run.core.tunnel.client import Callback
2323

2424
if TYPE_CHECKING:
2525
from nemo_run.core.execution.base import Executor

0 commit comments

Comments
 (0)