Skip to content

Commit 35051d4

Browse files
authored
Merge branch 'main' into main
2 parents 12ea91b + a264c75 commit 35051d4

File tree

5 files changed

+69
-0
lines changed

5 files changed

+69
-0
lines changed

torchx/runner/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,8 @@ def find_configs(dirs: Optional[Iterable[str]] = None) -> List[str]:
494494

495495
config = os.getenv(ENV_TORCHXCONFIG)
496496
if config is not None:
497+
if not config:
498+
return []
497499
configfile = Path(config)
498500
if not configfile.is_file():
499501
raise FileNotFoundError(

torchx/runner/test/config_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,12 @@ def test_get_configs(self) -> None:
275275
),
276276
)
277277

278+
def test_no_config(self) -> None:
279+
config_dir = self.tmpdir
280+
with patch.dict(os.environ, {ENV_TORCHXCONFIG: str("")}):
281+
configs = find_configs(dirs=[str(config_dir)])
282+
self.assertEqual([], configs)
283+
278284
def test_find_configs(self) -> None:
279285
config_dir = self.tmpdir
280286
cwd_dir = config_dir / "cwd"

torchx/specs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
RoleStatus,
4242
runopt,
4343
runopts,
44+
TORCHX_HOME,
4445
UnknownAppException,
4546
UnknownSchedulerException,
4647
VolumeMount,
@@ -53,6 +54,7 @@
5354

5455
GiB: int = 1024
5556

57+
5658
ResourceFactory = Callable[[], Resource]
5759

5860
AWS_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(

torchx/specs/api.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import inspect
1212
import json
1313
import logging as logger
14+
import os
15+
import pathlib
1416
import re
1517
import typing
1618
from dataclasses import asdict, dataclass, field
@@ -66,6 +68,32 @@
6668
RESET = "\033[0m"
6769

6870

71+
def TORCHX_HOME(*subdir_paths: str) -> pathlib.Path:
72+
"""
73+
Path to the "dot-directory" for torchx.
74+
Defaults to `~/.torchx` and is overridable via the `TORCHX_HOME` environment variable.
75+
76+
Usage:
77+
78+
.. doc-test::
79+
80+
from pathlib import Path
81+
from torchx.specs import TORCHX_HOME
82+
83+
assert TORCHX_HOME() == Path.home() / ".torchx"
84+
assert TORCHX_HOME("conda-pack-out") == Path.home() / ".torchx" / "conda-pack-out"
85+
```
86+
"""
87+
88+
default_dir = str(pathlib.Path.home() / ".torchx")
89+
torchx_home = pathlib.Path(os.getenv("TORCHX_HOME", default_dir))
90+
91+
torchx_home = torchx_home / os.path.sep.join(subdir_paths)
92+
torchx_home.mkdir(parents=True, exist_ok=True)
93+
94+
return torchx_home
95+
96+
6997
# ========================================
7098
# ==== Distributed AppDef API =======
7199
# ========================================

torchx/specs/test/api_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
import asyncio
1111
import concurrent
1212
import os
13+
import tempfile
1314
import time
1415
import unittest
1516
from dataclasses import asdict
17+
from pathlib import Path
1618
from typing import Dict, List, Mapping, Tuple, Union
19+
from unittest import mock
1720
from unittest.mock import MagicMock
1821

1922
import torchx.specs.named_resources_aws as named_resources_aws
@@ -40,9 +43,37 @@
4043
RoleStatus,
4144
runopt,
4245
runopts,
46+
TORCHX_HOME,
4347
)
4448

4549

50+
class TorchXHomeTest(unittest.TestCase):
51+
# guard against TORCHX_HOME set outside the test
52+
@mock.patch.dict(os.environ, {}, clear=True)
53+
def test_TORCHX_HOME_default(self) -> None:
54+
with tempfile.TemporaryDirectory() as tmpdir:
55+
user_home = Path(tmpdir) / "sally"
56+
with mock.patch("pathlib.Path.home", return_value=user_home):
57+
torchx_home = TORCHX_HOME()
58+
self.assertEqual(torchx_home, user_home / ".torchx")
59+
self.assertTrue(torchx_home.exists())
60+
61+
def test_TORCHX_HOME_override(self) -> None:
62+
with tempfile.TemporaryDirectory() as tmpdir:
63+
override_torchx_home = Path(tmpdir) / "test" / ".torchx"
64+
with mock.patch.dict(
65+
os.environ, {"TORCHX_HOME": str(override_torchx_home)}
66+
):
67+
torchx_home = TORCHX_HOME()
68+
conda_pack_out = TORCHX_HOME("conda-pack", "out")
69+
70+
self.assertEqual(override_torchx_home, torchx_home)
71+
self.assertEqual(torchx_home / "conda-pack" / "out", conda_pack_out)
72+
73+
self.assertTrue(torchx_home.is_dir())
74+
self.assertTrue(conda_pack_out.is_dir())
75+
76+
4677
class AppDryRunInfoTest(unittest.TestCase):
4778
def test_repr(self) -> None:
4879
request_mock = MagicMock()

0 commit comments

Comments
 (0)