diff --git a/torchx/cli/cmd_run.py b/torchx/cli/cmd_run.py index 23f96c8f6..567001e45 100644 --- a/torchx/cli/cmd_run.py +++ b/torchx/cli/cmd_run.py @@ -36,6 +36,7 @@ ) from torchx.util.log_tee_helpers import tee_logs from torchx.util.types import none_throws +from torchx.workspace import Workspace MISSING_COMPONENT_ERROR_MSG = ( @@ -92,7 +93,7 @@ def torchx_run_args_from_json(json_data: Dict[str, Any]) -> TorchXRunArgs: torchx_args = TorchXRunArgs(**filtered_json_data) if torchx_args.workspace == "": - torchx_args.workspace = f"file://{Path.cwd()}" + torchx_args.workspace = f"{Path.cwd()}" return torchx_args @@ -250,7 +251,7 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None: subparser.add_argument( "--workspace", "--buck-target", - default=f"file://{Path.cwd()}", + default=f"{Path.cwd()}", action=torchxconfig_run, help="local workspace to build/patch (buck-target of main binary if using buck)", ) @@ -289,12 +290,14 @@ def _run_inner(self, runner: Runner, args: TorchXRunArgs) -> None: else args.component_args ) try: + workspace = Workspace.from_str(args.workspace) if args.workspace else None + if args.dryrun: dryrun_info = runner.dryrun_component( args.component_name, component_args, args.scheduler, - workspace=args.workspace, + workspace=workspace, cfg=args.scheduler_cfg, parent_run_id=args.parent_run_id, ) diff --git a/torchx/cli/test/cmd_run_test.py b/torchx/cli/test/cmd_run_test.py index 4cdadaa58..8fc632a67 100644 --- a/torchx/cli/test/cmd_run_test.py +++ b/torchx/cli/test/cmd_run_test.py @@ -401,7 +401,7 @@ def test_verify_no_extra_args_stdin_with_boolean_flags(self) -> None: def test_verify_no_extra_args_stdin_with_value_args(self) -> None: """Test that arguments with values conflict with stdin.""" - args = self.parser.parse_args(["--stdin", "--workspace", "file:///custom/path"]) + args = self.parser.parse_args(["--stdin", "--workspace", "/custom/path"]) with self.assertRaises(SystemExit): self.cmd_run.verify_no_extra_args(args) @@ -499,7 +499,7 @@ def test_torchx_run_args_from_json(self) -> None: self.assertEqual(result.dryrun, False) self.assertEqual(result.wait, False) self.assertEqual(result.log, False) - self.assertEqual(result.workspace, f"file://{Path.cwd()}") + self.assertEqual(result.workspace, f"{Path.cwd()}") self.assertEqual(result.parent_run_id, None) self.assertEqual(result.tee_logs, False) self.assertEqual(result.component_args, {}) @@ -515,7 +515,7 @@ def test_torchx_run_args_from_json(self) -> None: "dryrun": True, "wait": True, "log": True, - "workspace": "file:///custom/path", + "workspace": "/custom/path", "parent_run_id": "parent123", "tee_logs": True, } @@ -529,7 +529,7 @@ def test_torchx_run_args_from_json(self) -> None: self.assertEqual(result2.dryrun, True) self.assertEqual(result2.wait, True) self.assertEqual(result2.log, True) - self.assertEqual(result2.workspace, "file:///custom/path") + self.assertEqual(result2.workspace, "/custom/path") self.assertEqual(result2.parent_run_id, "parent123") self.assertEqual(result2.tee_logs, True) @@ -626,7 +626,7 @@ def test_torchx_run_args_from_argparse(self) -> None: args.dryrun = True args.wait = False args.log = True - args.workspace = "file:///custom/workspace" + args.workspace = "/custom/workspace" args.parent_run_id = "parent_123" args.tee_logs = False @@ -654,7 +654,7 @@ def test_torchx_run_args_from_argparse(self) -> None: self.assertEqual(result.dryrun, True) self.assertEqual(result.wait, False) self.assertEqual(result.log, True) - self.assertEqual(result.workspace, "file:///custom/workspace") + self.assertEqual(result.workspace, "/custom/workspace") self.assertEqual(result.parent_run_id, "parent_123") self.assertEqual(result.tee_logs, False) self.assertEqual(result.component_args, {}) diff --git a/torchx/runner/api.py b/torchx/runner/api.py index 4efc8a83b..08d732238 100644 --- a/torchx/runner/api.py +++ b/torchx/runner/api.py @@ -54,7 +54,7 @@ from torchx.util.session import get_session_id_or_create_new, TORCHX_INTERNAL_SESSION_ID from torchx.util.types import none_throws -from torchx.workspace.api import WorkspaceMixin +from torchx.workspace.api import Workspace, WorkspaceMixin if TYPE_CHECKING: from typing_extensions import Self @@ -171,7 +171,7 @@ def run_component( component_args: Union[list[str], dict[str, Any]], scheduler: str, cfg: Optional[Mapping[str, CfgVal]] = None, - workspace: Optional[str] = None, + workspace: Optional[Union[Workspace, str]] = None, parent_run_id: Optional[str] = None, ) -> AppHandle: """ @@ -206,7 +206,7 @@ def run_component( ComponentNotFoundException: if the ``component_path`` is failed to resolve. """ - with log_event("run_component", workspace=workspace) as ctx: + with log_event("run_component") as ctx: dryrun_info = self.dryrun_component( component, component_args, @@ -217,7 +217,8 @@ def run_component( ) handle = self.schedule(dryrun_info) app = none_throws(dryrun_info._app) - ctx._torchx_event.workspace = workspace + + ctx._torchx_event.workspace = str(workspace) ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler) ctx._torchx_event.app_image = app.roles[0].image ctx._torchx_event.app_id = parse_app_handle(handle)[2] @@ -230,7 +231,7 @@ def dryrun_component( component_args: Union[list[str], dict[str, Any]], scheduler: str, cfg: Optional[Mapping[str, CfgVal]] = None, - workspace: Optional[str] = None, + workspace: Optional[Union[Workspace, str]] = None, parent_run_id: Optional[str] = None, ) -> AppDryRunInfo: """ @@ -259,7 +260,7 @@ def run( app: AppDef, scheduler: str, cfg: Optional[Mapping[str, CfgVal]] = None, - workspace: Optional[str] = None, + workspace: Optional[Union[Workspace, str]] = None, parent_run_id: Optional[str] = None, ) -> AppHandle: """ @@ -272,9 +273,7 @@ def run( An application handle that is used to call other action APIs on the app. """ - with log_event( - api="run", runcfg=json.dumps(cfg) if cfg else None, workspace=workspace - ) as ctx: + with log_event(api="run") as ctx: dryrun_info = self.dryrun( app, scheduler, @@ -283,10 +282,15 @@ def run( parent_run_id=parent_run_id, ) handle = self.schedule(dryrun_info) - ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler) - ctx._torchx_event.app_image = none_throws(dryrun_info._app).roles[0].image - ctx._torchx_event.app_id = parse_app_handle(handle)[2] - ctx._torchx_event.app_metadata = app.metadata + + event = ctx._torchx_event + event.scheduler = scheduler + event.runcfg = json.dumps(cfg) if cfg else None + event.workspace = str(workspace) + event.app_id = parse_app_handle(handle)[2] + event.app_image = none_throws(dryrun_info._app).roles[0].image + event.app_metadata = app.metadata + return handle def schedule(self, dryrun_info: AppDryRunInfo) -> AppHandle: @@ -320,21 +324,22 @@ def schedule(self, dryrun_info: AppDryRunInfo) -> AppHandle: """ scheduler = none_throws(dryrun_info._scheduler) - app_image = none_throws(dryrun_info._app).roles[0].image cfg = dryrun_info._cfg - with log_event( - "schedule", - scheduler, - app_image=app_image, - runcfg=json.dumps(cfg) if cfg else None, - ) as ctx: + with log_event("schedule") as ctx: sched = self._scheduler(scheduler) app_id = sched.schedule(dryrun_info) app_handle = make_app_handle(scheduler, self._name, app_id) + app = none_throws(dryrun_info._app) self._apps[app_handle] = app - _, _, app_id = parse_app_handle(app_handle) - ctx._torchx_event.app_id = app_id + + event = ctx._torchx_event + event.scheduler = scheduler + event.runcfg = json.dumps(cfg) if cfg else None + event.app_id = app_id + event.app_image = none_throws(dryrun_info._app).roles[0].image + event.app_metadata = app.metadata + return app_handle def name(self) -> str: @@ -345,7 +350,7 @@ def dryrun( app: AppDef, scheduler: str, cfg: Optional[Mapping[str, CfgVal]] = None, - workspace: Optional[str] = None, + workspace: Optional[Union[Workspace, str]] = None, parent_run_id: Optional[str] = None, ) -> AppDryRunInfo: """ @@ -414,7 +419,7 @@ def dryrun( "dryrun", scheduler, runcfg=json.dumps(cfg) if cfg else None, - workspace=workspace, + workspace=str(workspace), ): sched = self._scheduler(scheduler) resolved_cfg = sched.run_opts().resolve(cfg) @@ -429,7 +434,7 @@ def dryrun( logger.info( 'To disable workspaces pass: --workspace="" from CLI or workspace=None programmatically.' ) - sched.build_workspace_and_update_role(role, workspace, resolved_cfg) + sched.build_workspace_and_update_role2(role, workspace, resolved_cfg) if old_img != role.image: logger.info( diff --git a/torchx/runner/events/__init__.py b/torchx/runner/events/__init__.py index 8fab92a10..5a913e8d4 100644 --- a/torchx/runner/events/__init__.py +++ b/torchx/runner/events/__init__.py @@ -33,8 +33,9 @@ from .api import SourceType, TorchxEvent # noqa F401 -# pyre-fixme[9]: _events_logger is a global variable -_events_logger: logging.Logger = None +_events_logger: Optional[logging.Logger] = None + +log: logging.Logger = logging.getLogger(__name__) def _get_or_create_logger(destination: str = "null") -> logging.Logger: @@ -51,19 +52,28 @@ def _get_or_create_logger(destination: str = "null") -> logging.Logger: a new logger if None provided. """ global _events_logger + if _events_logger: return _events_logger - logging_handler = get_logging_handler(destination) - logging_handler.setLevel(logging.DEBUG) - _events_logger = logging.getLogger(f"torchx-events-{destination}") - # Do not propagate message to the root logger - _events_logger.propagate = False - _events_logger.addHandler(logging_handler) - return _events_logger + else: + logging_handler = get_logging_handler(destination) + logging_handler.setLevel(logging.DEBUG) + _events_logger = logging.getLogger(f"torchx-events-{destination}") + # Do not propagate message to the root logger + _events_logger.propagate = False + _events_logger.addHandler(logging_handler) + + assert _events_logger # make type-checker happy + return _events_logger def record(event: TorchxEvent, destination: str = "null") -> None: - _get_or_create_logger(destination).info(event.serialize()) + try: + serialized_event = event.serialize() + except Exception: + log.exception("failed to serialize event, will not record event") + else: + _get_or_create_logger(destination).info(serialized_event) class log_event: diff --git a/torchx/runner/events/api.py b/torchx/runner/events/api.py index f03815e75..6bb9b068c 100644 --- a/torchx/runner/events/api.py +++ b/torchx/runner/events/api.py @@ -29,7 +29,7 @@ class TorchxEvent: scheduler: Scheduler that is used to execute request api: Api name app_id: Unique id that is set by the underlying scheduler - image: Image/container bundle that is used to execute request. + app_image: Image/container bundle that is used to execute request. app_metadata: metadata to the app (treatment of metadata is scheduler dependent) runcfg: Run config that was used to schedule app. source: Type of source the event is generated. diff --git a/torchx/runner/test/config_test.py b/torchx/runner/test/config_test.py index 901018c9a..6abff46fa 100644 --- a/torchx/runner/test/config_test.py +++ b/torchx/runner/test/config_test.py @@ -27,6 +27,7 @@ from torchx.schedulers.api import DescribeAppResponse, ListAppResponse, Stream from torchx.specs import AppDef, AppDryRunInfo, CfgVal, runopts from torchx.test.fixtures import TestWithTmpDir +from torchx.workspace import Workspace class TestScheduler(Scheduler): @@ -506,3 +507,31 @@ def test_dump_and_load_all_registered_schedulers(self) -> None: opt_name in cfg, f"missing {opt_name} in {sched} run opts with cfg {cfg}", ) + + def test_get_workspace_config(self) -> None: + configdir = self.tmpdir + self.write( + str(configdir / ".torchxconfig"), + """# +[cli:run] +workspace = + /home/foo/third-party/verl: verl + /home/foo/bar/scripts/.torchxconfig: verl/.torchxconfig + /home/foo/baz: +""", + ) + + workspace_config = get_config( + prefix="cli", name="run", key="workspace", dirs=[str(configdir)] + ) + self.assertIsNotNone(workspace_config) + + workspace = Workspace.from_str(workspace_config) + self.assertDictEqual( + { + "/home/foo/third-party/verl": "verl", + "/home/foo/bar/scripts/.torchxconfig": "verl/.torchxconfig", + "/home/foo/baz": "", + }, + workspace.projects, + ) diff --git a/torchx/schedulers/api.py b/torchx/schedulers/api.py index 48ca64849..14f00547c 100644 --- a/torchx/schedulers/api.py +++ b/torchx/schedulers/api.py @@ -12,7 +12,7 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Generic, Iterable, List, Optional, TypeVar +from typing import Generic, Iterable, List, Optional, TypeVar, Union from torchx.specs import ( AppDef, @@ -23,7 +23,7 @@ RoleStatus, runopts, ) -from torchx.workspace.api import WorkspaceMixin +from torchx.workspace.api import Workspace, WorkspaceMixin DAYS_IN_2_WEEKS = 14 @@ -131,7 +131,7 @@ def submit( self, app: A, cfg: T, - workspace: Optional[str] = None, + workspace: Optional[Union[Workspace, str]] = None, ) -> str: """ Submits the application to be run by the scheduler. @@ -144,10 +144,9 @@ def submit( # pyre-fixme: Generic cfg type passed to resolve resolved_cfg = self.run_opts().resolve(cfg) if workspace: - sched = self - assert isinstance(sched, WorkspaceMixin) - role = app.roles[0] - sched.build_workspace_and_update_role(role, workspace, resolved_cfg) + assert isinstance(self, WorkspaceMixin) + self.build_workspace_and_update_role2(app.roles[0], workspace, resolved_cfg) + # pyre-fixme: submit_dryrun takes Generic type for resolved_cfg dryrun_info = self.submit_dryrun(app, resolved_cfg) return self.schedule(dryrun_info) @@ -356,13 +355,14 @@ def _validate(self, app: A, scheduler: str, cfg: T) -> None: Raises error if application is not compatible with scheduler """ - if isinstance(app, AppDef): - for role in app.roles: - if role.resource == NULL_RESOURCE: - raise ValueError( - f"No resource for role: {role.image}." - f" Did you forget to attach resource to the role" - ) + if not isinstance(app, AppDef): + return + + for role in app.roles: + if role.resource == NULL_RESOURCE: + raise ValueError( + f"No resource for role: {role.image}. Did you forget to attach resource to the role" + ) def filter_regex(regex: str, data: Iterable[str]) -> Iterable[str]: diff --git a/torchx/workspace/__init__.py b/torchx/workspace/__init__.py index 5625ce41d..34405292f 100644 --- a/torchx/workspace/__init__.py +++ b/torchx/workspace/__init__.py @@ -22,4 +22,4 @@ * ``memory://foo-bar/`` an in-memory workspace for notebook/programmatic usage """ -from torchx.workspace.api import walk_workspace, WorkspaceMixin # noqa: F401 +from torchx.workspace.api import walk_workspace, Workspace, WorkspaceMixin # noqa: F401 diff --git a/torchx/workspace/api.py b/torchx/workspace/api.py index 694e3fb57..b3b8ab681 100644 --- a/torchx/workspace/api.py +++ b/torchx/workspace/api.py @@ -9,9 +9,22 @@ import abc import fnmatch import posixpath +import shutil +import tempfile import warnings from dataclasses import dataclass -from typing import Any, Dict, Generic, Iterable, Mapping, Tuple, TYPE_CHECKING, TypeVar +from pathlib import Path +from typing import ( + Any, + Dict, + Generic, + Iterable, + Mapping, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) from torchx.specs import AppDef, CfgVal, Role, runopts @@ -75,6 +88,71 @@ def build_workspace(self, sync: bool = True) -> PkgInfo[PackageType]: pass +@dataclass +class Workspace: + """ + Specifies a local "workspace" (a set of directories). Workspaces are ad-hoc built + into an (usually ephemeral) image. This effectively mirrors the local code changes + at job submission time. + + For example: + + 1. ``projects={"~/github/torch": "torch"}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/torch/**`` + 2. ``projects={"~/github/torch": ""}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/**`` + + The exact location of ``$REMOTE_WORKSPACE_ROOT`` is implementation dependent and varies between + different implementations of :py:class:`~torchx.workspace.api.WorkspaceMixin`. + Check the scheduler documentation for details on which workspace it supports. + + Note: ``projects`` maps the location of the local project to a sub-directory in the remote workspace root directory. + Typically the local project location is a directory path (e.g. ``/home/foo/github/torch``). + + + Attributes: + projects: mapping of local project to the sub-dir in the remote workspace dir. + """ + + projects: dict[str, str] + + def is_unmapped_single_project(self) -> bool: + """ + Returns ``True`` if this workspace only has 1 project + and its target mapping is an empty string. + """ + return len(self.projects) == 1 and not next(iter(self.projects.values())) + + @staticmethod + def from_str(workspace: str) -> "Workspace": + import yaml + + projects = yaml.safe_load(workspace) + if isinstance(projects, str): # single project workspace + projects = {projects: ""} + else: # multi-project workspace + # Replace None mappings with "" (empty string) + projects = {k: ("" if v is None else v) for k, v in projects.items()} + + return Workspace(projects) + + def __str__(self) -> str: + """ + Returns a string representation of the Workspace by concatenating + the project mappings using ';' as a delimiter and ':' between key and value. + If the single-project workspace with no target mapping, then simply + returns the src (local project dir) + + NOTE: meant to be used for logging purposes not serde. + Therefore not symmetric with :py:func:`Workspace.from_str`. + + """ + if self.is_unmapped_single_project(): + return next(iter(self.projects)) + else: + return ";".join( + k if not v else f"{k}:{v}" for k, v in self.projects.items() + ) + + class WorkspaceMixin(abc.ABC, Generic[T]): """ Note: (Prototype) this interface may change without notice! @@ -100,9 +178,50 @@ def workspace_opts(self) -> runopts: """ return runopts() + def build_workspace_and_update_role2( + self, + role: Role, + workspace: Union[Workspace, str], + cfg: Mapping[str, CfgVal], + ) -> None: + """ + Same as :py:meth:`build_workspace_and_update_role` but operates + on :py:class:`Workspace` (supports multi-project workspaces) + as well as ``str`` (for backwards compatibility). + + If ``workspace`` is a ``str`` this method simply calls + :py:meth:`build_workspace_and_update_role`. + + If ``workspace`` is :py:class:`Workspace` then the default + impl copies all the projects into a tmp directory and passes the tmp dir to + :py:meth:`build_workspace_and_update_role` + + Subclasses can override this method to customize multi-project + workspace building logic. + """ + if isinstance(workspace, Workspace): + if not workspace.is_unmapped_single_project(): + with tempfile.TemporaryDirectory(suffix="torchx_workspace_") as outdir: + for src, dst in workspace.projects.items(): + dst_path = Path(outdir) / dst + if Path(src).is_file(): + shutil.copy2(src, dst_path) + else: # src is dir + shutil.copytree(src, dst_path, dirs_exist_ok=True) + + self.build_workspace_and_update_role(role, outdir, cfg) + return + else: # single project workspace with no target mapping (treat like a str workspace) + workspace = str(workspace) + + self.build_workspace_and_update_role(role, workspace, cfg) + @abc.abstractmethod def build_workspace_and_update_role( - self, role: Role, workspace: str, cfg: Mapping[str, CfgVal] + self, + role: Role, + workspace: str, + cfg: Mapping[str, CfgVal], ) -> None: """ Builds the specified ``workspace`` with respect to ``img`` diff --git a/torchx/workspace/test/api_test.py b/torchx/workspace/test/api_test.py new file mode 100644 index 000000000..352d48530 --- /dev/null +++ b/torchx/workspace/test/api_test.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import shutil + +from pathlib import Path +from typing import Mapping + +from torchx.specs import CfgVal, Role +from torchx.test.fixtures import TestWithTmpDir + +from torchx.workspace.api import Workspace, WorkspaceMixin + + +class TestWorkspace(WorkspaceMixin[None]): + def __init__(self, tmpdir: Path) -> None: + self.tmpdir = tmpdir + + def build_workspace_and_update_role( + self, role: Role, workspace: str, cfg: Mapping[str, CfgVal] + ) -> None: + role.image = "bar" + role.metadata["workspace"] = workspace + + if not workspace.startswith("//"): + # to validate the merged workspace dir copy its content to the tmpdir + shutil.copytree(workspace, self.tmpdir) + + +class WorkspaceTest(TestWithTmpDir): + + def test_to_string_single_project_workspace(self) -> None: + self.assertEqual( + "/home/foo/bar", + str(Workspace(projects={"/home/foo/bar": ""})), + ) + + def test_to_string_multi_project_workspace(self) -> None: + workspace = Workspace( + projects={ + "/home/foo/workspace/myproj": "", + "/home/foo/github/torch": "torch", + } + ) + + self.assertEqual( + "/home/foo/workspace/myproj;/home/foo/github/torch:torch", + str(workspace), + ) + + def test_is_unmapped_single_project_workspace(self) -> None: + self.assertTrue( + Workspace(projects={"/home/foo/bar": ""}).is_unmapped_single_project() + ) + + self.assertFalse( + Workspace(projects={"/home/foo/bar": "baz"}).is_unmapped_single_project() + ) + + self.assertFalse( + Workspace( + projects={"/home/foo/bar": "", "/home/foo/torch": ""} + ).is_unmapped_single_project() + ) + + self.assertFalse( + Workspace( + projects={"/home/foo/bar": "", "/home/foo/torch": "pytorch"} + ).is_unmapped_single_project() + ) + + def test_from_str_single_project(self) -> None: + self.assertDictEqual( + {"/home/foo/bar": ""}, + Workspace.from_str("/home/foo/bar").projects, + ) + + self.assertDictEqual( + {"/home/foo/bar": "baz"}, + Workspace.from_str("/home/foo/bar: baz").projects, + ) + + def test_from_str_multi_project(self) -> None: + self.assertDictEqual( + { + "/home/foo/bar": "", + "/home/foo/third-party/verl": "verl", + }, + Workspace.from_str( + """# +/home/foo/bar: +/home/foo/third-party/verl: verl +""" + ).projects, + ) + + def test_build_and_update_role2_str_workspace(self) -> None: + proj = self.tmpdir / "github" / "torch" + proj.mkdir(parents=True) + (proj / "torch.py").touch() + + role = Role(name="__IGNORED__", image="foo") + out = self.tmpdir / "workspace-merged" + TestWorkspace(out).build_workspace_and_update_role2( + role, + str(proj), + cfg={}, + ) + + # make sure build_workspace_and_update_role has been called + # by checking that the image is updated from "foo" to "bar" + self.assertEqual(role.image, "bar") + self.assertTrue((out / "torch.py").exists()) + + def test_build_and_update_role2_unmapped_single_project_workspace(self) -> None: + proj = self.tmpdir / "github" / "torch" + proj.mkdir(parents=True) + (proj / "torch.py").touch() + + role = Role(name="__IGNORED__", image="foo") + out = self.tmpdir / "workspace-merged" + TestWorkspace(out).build_workspace_and_update_role2( + role, + Workspace(projects={str(proj): ""}), + cfg={}, + ) + + self.assertEqual(role.image, "bar") + self.assertTrue((out / "torch.py").exists()) + + def test_build_and_update_role2_unmapped_single_project_workspace_buck( + self, + ) -> None: + buck_target = "//foo/bar:main" + + role = Role(name="__IGNORED__", image="foo") + out = self.tmpdir / "workspace-merged" + TestWorkspace(out).build_workspace_and_update_role2( + role, + Workspace(projects={buck_target: ""}), + cfg={}, + ) + self.assertEqual(role.image, "bar") + self.assertEqual(role.metadata["workspace"], buck_target) + + def test_build_and_update_role2_multi_project_workspace(self) -> None: + proj1 = self.tmpdir / "github" / "torch" + proj1.mkdir(parents=True) + (proj1 / "torch.py").touch() + + proj2 = self.tmpdir / "github" / "verl" + proj2.mkdir(parents=True) + (proj2 / "verl.py").touch() + + file1 = self.tmpdir / ".torchxconfig" + file1.touch() + + role = Role(name="__IGNORED__", image="foo") + workspace = Workspace( + projects={ + str(proj1): "", + str(proj2): "verl", + str(file1): "verl/.torchxconfig", + } + ) + + out = self.tmpdir / "workspace-merged" + TestWorkspace(out).build_workspace_and_update_role2( + role, + workspace, + cfg={}, + ) + + self.assertEqual(role.image, "bar") + self.assertTrue((out / "torch.py").exists()) + self.assertTrue((out / "verl" / "verl.py").exists()) + self.assertTrue((out / "verl" / ".torchxconfig").exists())