|
18 | 18 | from torchx.runner.config import ( |
19 | 19 | apply, |
20 | 20 | dump, |
| 21 | + ENV_TORCHXCONFIG, |
| 22 | + find_configs, |
21 | 23 | get_config, |
22 | 24 | get_configs, |
23 | 25 | load, |
@@ -259,6 +261,44 @@ def test_get_configs(self) -> None: |
259 | 261 | ), |
260 | 262 | ) |
261 | 263 |
|
| 264 | + def test_find_configs(self) -> None: |
| 265 | + config_dir = Path(self.test_dir) |
| 266 | + cwd_dir = config_dir / "cwd" |
| 267 | + custom_dir = config_dir / "custom" |
| 268 | + |
| 269 | + cwd_dir.mkdir() |
| 270 | + custom_dir.mkdir() |
| 271 | + |
| 272 | + base_config = config_dir / ".torchxconfig" |
| 273 | + cwd_config = cwd_dir / ".torchxconfig" |
| 274 | + override_config = custom_dir / ".torchxconfig" |
| 275 | + |
| 276 | + base_config.touch() |
| 277 | + cwd_config.touch() |
| 278 | + override_config.touch() |
| 279 | + |
| 280 | + # should find configs in the specified dirs |
| 281 | + configs = find_configs(dirs=[str(config_dir)]) |
| 282 | + self.assertEqual([str(base_config)], configs) |
| 283 | + |
| 284 | + # should find config in cwd if no dirs is specified |
| 285 | + with patch(PATH_CWD, return_value=str(cwd_dir)): |
| 286 | + configs = find_configs() |
| 287 | + self.assertEqual([str(cwd_config)], configs) |
| 288 | + |
| 289 | + # if TORCHXCONFIG env var exists then should just return the config specified |
| 290 | + with patch.dict(os.environ, {ENV_TORCHXCONFIG: str(override_config)}): |
| 291 | + configs = find_configs(dirs=[str(config_dir)]) |
| 292 | + self.assertEqual([str(override_config)], configs) |
| 293 | + |
| 294 | + # if TORCHXCONFIG points to a non-existing file, then assert exception |
| 295 | + with patch.dict( |
| 296 | + os.environ, |
| 297 | + {ENV_TORCHXCONFIG: str(config_dir / ".torchxconfig_nonexistent")}, |
| 298 | + ): |
| 299 | + with self.assertRaises(FileNotFoundError): |
| 300 | + find_configs(dirs=[str(config_dir)]) |
| 301 | + |
262 | 302 | def test_get_config(self) -> None: |
263 | 303 | configdir0 = Path(self.test_dir) / "test_load_component_defaults" / "0" |
264 | 304 | configdir1 = Path(self.test_dir) / "test_load_component_defaults" / "1" |
@@ -286,6 +326,20 @@ def test_get_config(self) -> None: |
286 | 326 | get_config(prefix="badprefix", name="dist.ddp", key="j", dirs=dirs), |
287 | 327 | ) |
288 | 328 |
|
| 329 | + # check that if TORCHXCONFIG is set then only that config is loaded |
| 330 | + override_config = Path(self.test_dir) / ".torchxconfig_custom" |
| 331 | + override_config_contents = """ |
| 332 | +[component:dist.ddp] |
| 333 | +image = foobar_custom |
| 334 | + """ |
| 335 | + self._write(str(override_config), override_config_contents) |
| 336 | + |
| 337 | + with patch.dict(os.environ, {ENV_TORCHXCONFIG: str(override_config)}): |
| 338 | + self.assertDictEqual( |
| 339 | + {"image": "foobar_custom"}, |
| 340 | + get_configs(prefix="component", name="dist.ddp", dirs=dirs), |
| 341 | + ) |
| 342 | + |
289 | 343 | def test_load(self) -> None: |
290 | 344 | cfg = {} |
291 | 345 | load(scheduler="local_cwd", f=StringIO(_CONFIG), cfg=cfg) |
@@ -328,33 +382,6 @@ def test_apply_dirs(self, _) -> None: |
328 | 382 | self.assertEqual(100, cfg.get("i")) |
329 | 383 | self.assertEqual(1.2, cfg.get("f")) |
330 | 384 |
|
331 | | - @patch( |
332 | | - TORCHX_GET_SCHEDULERS, |
333 | | - return_value={"test": TestScheduler()}, |
334 | | - ) |
335 | | - def test_apply_from_annotated_config(self, _) -> None: |
336 | | - def mock_getenv(x: str, y: Optional[str] = None) -> Optional[str]: |
337 | | - if x != "TORCHX_CONFIG": |
338 | | - return os.environ.get(x, y) |
339 | | - else: |
340 | | - return f"{self.test_dir}/another_torchx_config" |
341 | | - |
342 | | - with patch(PATH_CWD, return_value=Path(self.test_dir)): |
343 | | - with patch( |
344 | | - "torchx.runner.config.os.getenv", |
345 | | - mock_getenv, |
346 | | - ): |
347 | | - cfg: Dict[str, CfgVal] = {"s": "runtime_value"} |
348 | | - apply( |
349 | | - scheduler="test", |
350 | | - cfg=cfg, |
351 | | - # these dirs will be ignored |
352 | | - dirs=[str(Path(self.test_dir) / "home" / "bob"), self.test_dir], |
353 | | - ) |
354 | | - self.assertEqual("runtime_value", cfg.get("s")) |
355 | | - self.assertEqual(200, cfg.get("i")) |
356 | | - self.assertEqual(None, cfg.get("f")) |
357 | | - |
358 | 385 | def test_dump_invalid_scheduler(self) -> None: |
359 | 386 | with self.assertRaises(ValueError): |
360 | 387 | dump(f=StringIO(), schedulers=["does-not-exist"]) |
|
0 commit comments