Skip to content

Commit 0b7476f

Browse files
committed
Add a helper function get_executor for retrieving executors from NEMORUN_HOME
Signed-off-by: Hemil Desai <hemild@nvidia.com>
1 parent 38265c4 commit 0b7476f

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

src/nemo_run/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616
from nemo_run import cli
1717
from nemo_run.api import autoconvert, dryrun_fn
1818
from nemo_run.config import Config, ConfigurableMixin, Partial, Script
19-
from nemo_run.core.execution.base import Executor, ExecutorMacros, FaultTolerance, Torchrun
19+
from nemo_run.core.execution.base import (
20+
Executor,
21+
ExecutorMacros,
22+
FaultTolerance,
23+
Torchrun,
24+
get_executor,
25+
)
2026
from nemo_run.core.execution.docker import DockerExecutor
2127
from nemo_run.core.execution.local import LocalExecutor
2228
from nemo_run.core.execution.skypilot import SkypilotExecutor
@@ -40,6 +46,7 @@
4046
"DockerExecutor",
4147
"dryrun_fn",
4248
"Executor",
49+
"get_executor",
4350
"ExecutorMacros",
4451
"Experiment",
4552
"FaultTolerance",

src/nemo_run/core/execution/base.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import copy
17+
import importlib.util
1718
import os
1819
from dataclasses import asdict, dataclass, field
1920
from string import Template
@@ -23,7 +24,7 @@
2324
from torchx.specs import Role
2425
from typing_extensions import Self
2526

26-
from nemo_run.config import ConfigurableMixin
27+
from nemo_run.config import NEMORUN_HOME, ConfigurableMixin
2728
from nemo_run.core.packaging.base import Packager
2829

2930

@@ -226,3 +227,49 @@ def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
226227
return filenames
227228

228229
def cleanup(self, handle: str): ...
230+
231+
232+
def get_executor(name: str, file_path: Optional[str] = None) -> Executor:
233+
"""
234+
Retrieves an executor instance by name from a specified or default Python file.
235+
The file must contain a global dict called EXECUTOR_MAP, which maps executor names to their corresponding instances.
236+
237+
This function dynamically imports the file_path, searches for the EXECUTOR_MAP dictionary
238+
and returns the value corresponding to the given name.
239+
240+
This functionality allows you to define all your executors in a single file which lives separately from your codebase.
241+
It is similar to ~/.ssh/config and allows you to use executors across your projects without having to redefine them.
242+
243+
Example:
244+
executor = get_executor("local", file_path="path/to/executors.py")
245+
executor = get_executor("gpu") # Uses the default location of os.path.join(NEMORUN_HOME, "executors.py")
246+
247+
Args:
248+
name (str): The name of the executor to retrieve.
249+
file_path (Optional[str]): The path to the Python file containing the executor definitions.
250+
Defaults to None, in which case the default location of os.path.join(NEMORUN_HOME, "executors.py") is used.
251+
252+
The file_path is expected to be a string representing a file path with the following structure:
253+
- It should be a path to a Python file (with a .py extension).
254+
- The file should contain a dictionary named `EXECUTOR_MAP` that maps executor names to their corresponding instances.
255+
- The file can be located anywhere in the file system, but if not provided, it defaults to `NEMORUN_HOME/executors.py`.
256+
257+
Returns:
258+
Executor: The executor instance corresponding to the given name.
259+
260+
Raises:
261+
AttributeError: If the file at the specified path does not contain an `EXECUTOR_MAP` dictionary.
262+
AssertionError: If the given executor name is not found in the `EXECUTOR_MAP` dictionary.
263+
"""
264+
265+
if not file_path:
266+
file_path = os.path.join(NEMORUN_HOME, "executors.py")
267+
268+
spec = importlib.util.spec_from_file_location("executors", file_path)
269+
assert spec
270+
module = importlib.util.module_from_spec(spec)
271+
assert spec.loader
272+
spec.loader.exec_module(module)
273+
executor_map = getattr(module, "EXECUTOR_MAP")
274+
assert name in executor_map, f"Executor {name} not found."
275+
return executor_map[name]

0 commit comments

Comments
 (0)