
Commit fb6bdc5

James Sun authored and facebook-github-bot committed
link launch and sync conda/workspace locations (#742)
Summary:
Pull Request resolved: #742
X-link: #742

Make sure the conda/workspace locations during launch map with the locations when we sync.

Reviewed By: kiukchung

Differential Revision: D79516268

fbshipit-source-id: 80ed66f3dfc04b35c2fd66bc20ad910bdb800070
1 parent 5b5a94b commit fb6bdc5
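For orientation, a minimal sketch of the call shape this change enables. The scheduler name and workspace path are placeholders, and the `proc_mesh` import path is an assumption, not something this commit shows:

    from monarch.actor import proc_mesh  # import path assumed
    from monarch.tools.config import defaults

    async def main() -> None:
        pm = await proc_mesh(gpus=1)
        # Hand the launch config's workspace to sync, so launch and
        # sync agree on the conda/workspace locations.
        config = defaults.config("slurm", "/path/to/workspace")
        await pm.sync_workspace(
            workspace=config.workspace, conda=False, auto_reload=True
        )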

4 files changed: +79 -21 lines changed

python/monarch/_src/actor/proc_mesh.py

Lines changed: 18 additions & 14 deletions
@@ -8,7 +8,6 @@
 
 import asyncio
 import logging
-import os
 import sys
 import threading
 import warnings
@@ -70,6 +69,8 @@
 from monarch._src.actor.endpoint import endpoint
 from monarch._src.actor.future import DeprecatedNotAFuture, Future
 from monarch._src.actor.shape import MeshTrait
+from monarch.tools.config import Workspace
+from monarch.tools.utils import conda as conda_utils
 
 HAS_TENSOR_ENGINE = False
 try:
@@ -369,7 +370,10 @@ def rank_tensors(self) -> Dict[str, "Tensor"]:
         return self._device_mesh.ranks
 
     async def sync_workspace(
-        self, conda: bool = False, auto_reload: bool = False
+        self,
+        workspace: Workspace = None,
+        conda: bool = False,
+        auto_reload: bool = False,
     ) -> None:
         if self._code_sync_client is None:
             self._code_sync_client = CodeSyncMeshClient.spawn_blocking(
@@ -382,21 +386,21 @@ async def sync_workspace(
         # The workspace shape (i.e. only perform one rsync per host).
         assert set(self._shape.labels).issubset({"gpus", "hosts"})
 
-        # TODO(agallagher): Is there a better way to infer/set the local
-        # workspace dir, rather than use PWD?
-        workspaces = [
-            WorkspaceConfig(
-                local=Path(os.getcwd()),
-                remote=RemoteWorkspace(
-                    location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
-                    shape=WorkspaceShape.shared("gpus"),
+        workspaces = []
+        if workspace is not None:
+            workspaces.append(
+                WorkspaceConfig(
+                    local=Path(workspace),
+                    remote=RemoteWorkspace(
+                        location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
+                        shape=WorkspaceShape.shared("gpus"),
+                    ),
+                    method=CodeSyncMethod.Rsync,
                 ),
-                method=CodeSyncMethod.Rsync,
-            ),
-        ]
+            )
 
         # If `conda` is set, also sync the currently activated conda env.
-        conda_prefix = os.environ.get("CONDA_PREFIX")
+        conda_prefix = conda_utils.active_env_dir()
         if conda and conda_prefix is not None:
             workspaces.append(
                 WorkspaceConfig(
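Net effect of the hunk above: `sync_workspace` no longer infers the local workspace from PWD. If `workspace` is None, nothing is rsynced; the active conda env is still synced when `conda=True` and an env is detected via `conda_utils.active_env_dir()`. A sketch of the two call patterns, assuming a proc mesh `pm` created elsewhere:

    async def demo(pm) -> None:
        # Rsync an explicit workspace directory to WORKSPACE_DIR on each host.
        await pm.sync_workspace(workspace="/path/to/workspace")

        # No workspace to rsync; only the active conda env (if any) is synced.
        await pm.sync_workspace(workspace=None, conda=True)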

python/monarch/tools/config/__init__.py

Lines changed: 14 additions & 4 deletions
@@ -6,9 +6,11 @@
 
 # pyre-strict
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, TYPE_CHECKING
 
-from torchx.specs import Role
+# Defer the import of Role to avoid requiring torchx at import time
+if TYPE_CHECKING:
+    from torchx.specs import Role
 
 
 NOT_SET: str = "__NOT_SET__"
@@ -20,10 +22,18 @@ class UnnamedAppDef:
     A TorchX AppDef without a name.
     """
 
-    roles: List[Role] = field(default_factory=list)
+    roles: List["Role"] = field(default_factory=list)
     metadata: Dict[str, str] = field(default_factory=dict)
 
 
+# TODO: provide a proper Workspace class to support
+# - multiple workspaces
+# - empty workspaces
+# - no workspace
+# - experimental directories
+Workspace = str | None
+
+
 @dataclass
 class Config:
     """
@@ -32,6 +42,6 @@ class Config:
 
     scheduler: str = NOT_SET
     scheduler_args: dict[str, Any] = field(default_factory=dict)
-    workspace: Optional[str] = None
+    workspace: Workspace = None
    dryrun: bool = False
    appdef: UnnamedAppDef = field(default_factory=UnnamedAppDef)
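The deferred import above uses the standard `typing.TYPE_CHECKING` idiom: the guarded block is evaluated only by type checkers, never at runtime, which is why the `roles` annotation becomes the string `"Role"`. A self-contained sketch of the same pattern (class name is illustrative):

    from dataclasses import dataclass, field
    from typing import List, TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by pyre/mypy only; torchx is not required at runtime.
        from torchx.specs import Role

    @dataclass
    class Example:
        # String annotation, so it is not evaluated when the module loads.
        roles: List["Role"] = field(default_factory=list)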

python/monarch/tools/config/defaults.py

Lines changed: 3 additions & 3 deletions
@@ -8,10 +8,10 @@
 
 """Defines defaults for ``monarch.tools``"""
 
-from typing import Callable, Optional
+from typing import Callable
 
 from monarch.tools.components import hyperactor
-from monarch.tools.config import Config, UnnamedAppDef
+from monarch.tools.config import Config, UnnamedAppDef, Workspace
 
 from torchx import specs
 from torchx.schedulers import (
@@ -40,7 +40,7 @@ def scheduler_factories() -> dict[str, SchedulerFactory]:
     }
 
 
-def config(scheduler: str, workspace: Optional[str] = None) -> Config:
+def config(scheduler: str, workspace: Workspace = None) -> Config:
     """The default :py:class:`~monarch.tools.config.Config` to use when submitting to the provided ``scheduler``."""
     return Config(scheduler=scheduler, workspace=workspace)
 
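Since `Workspace` is an alias for `str | None`, the factory accepts the same values as before. A usage sketch (scheduler name and path are placeholders):

    from monarch.tools.config import defaults

    cfg = defaults.config("slurm", "/path/to/workspace")
    cfg_none = defaults.config("slurm")  # workspace defaults to None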

python/tests/test_python_actors.py

Lines changed: 44 additions & 0 deletions
@@ -37,6 +37,7 @@
     local_proc_mesh,
     proc_mesh,
 )
+from monarch.tools.config import defaults
 from typing_extensions import assert_type
 
 
@@ -950,6 +951,49 @@ async def test_same_actor_twice() -> None:
     ), f"Expected error message about duplicate actor name, got: {error_msg}"
 
 
+class LsActor(Actor):
+    def __init__(self, workspace: str):
+        self.workspace = workspace
+
+    @endpoint
+    async def ls(self) -> list[str]:
+        return os.listdir(self.workspace)
+
+
+async def test_sync_workspace() -> None:
+    pm = await proc_mesh(gpus=1)
+
+    # create two workspaces: one for local and one for remote
+    with tempfile.TemporaryDirectory() as workspace_src, tempfile.TemporaryDirectory() as workspace_dst, unittest.mock.patch.dict(
+        os.environ, {"WORKSPACE_DIR": workspace_dst}
+    ):
+        os.environ["WORKSPACE_DIR"] = workspace_dst
+        config = defaults.config("slurm", workspace_src)
+        await pm.sync_workspace(
+            workspace=config.workspace, conda=False, auto_reload=True
+        )
+
+        # no file in remote workspace initially
+        am = await pm.spawn("ls", LsActor, workspace_dst)
+        for item in list(am.ls.call().get()):
+            assert len(item[1]) == 0
+
+        # write a file to local workspace
+        file_path = os.path.join(workspace_src, "new_file")
+        with open(file_path, "w") as f:
+            f.write("hello world")
+            f.flush()
+
+        # force a sync and it should populate on the dst workspace
+        await pm.sync_workspace(config.workspace, conda=False, auto_reload=True)
+        for item in list(am.ls.call().get()):
+            assert len(item[1]) == 1
+            assert item[1][0] == "new_file"
+            file_path = os.path.join(workspace_dst, item[1][0])
+            with open(file_path, "r") as f:
+                assert f.readline() == "hello world"
+
+
 class TestActorMeshStop(unittest.IsolatedAsyncioTestCase):
     async def test_actor_mesh_stop(self) -> None:
         pm = proc_mesh(gpus=2)
