Skip to content

Commit 6d47bf1

Browse files
authored
Fix expanding home directory for Trainer's default_root_dir (#19179)
1 parent 6dfa5cc commit 6d47bf1

File tree

3 files changed

+14
-1
lines changed

3 files changed

+14
-1
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8383
- Fixed issue where the `precision="transformer-engine"` argument would not replace layers by default ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
8484

8585

86+
- Fixed `Trainer` not expanding the `default_root_dir` if it has the `~` (home) prefix ([#19179](https://github.com/Lightning-AI/lightning/pull/19179))
87+
88+
8689
## [2.1.2] - 2023-11-15
8790

8891
### Fixed

src/lightning/pytorch/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1285,7 +1285,7 @@ def default_root_dir(self) -> str:
12851285
12861286
"""
12871287
if _is_local_file_protocol(self._default_root_dir):
1288-
return os.path.normpath(self._default_root_dir)
1288+
return os.path.normpath(os.path.expanduser(self._default_root_dir))
12891289
return self._default_root_dir
12901290

12911291
@property

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2064,3 +2064,13 @@ def test_init_module_context(monkeypatch):
20642064
with pytest.warns(PossibleUserWarning, match="can't place .* on the device"), trainer.init_module():
20652065
pass
20662066
strategy.tensor_init_context.assert_called_once()
2067+
2068+
2069+
def test_expand_home_trainer():
2070+
"""Test that the dirpath gets expanded if it contains `~`."""
2071+
home_root = Path.home()
2072+
2073+
trainer = Trainer(default_root_dir="~/trainer")
2074+
assert trainer.default_root_dir == str(home_root / "trainer")
2075+
trainer = Trainer(default_root_dir=Path("~/trainer"))
2076+
assert trainer.default_root_dir == str(home_root / "trainer")

0 commit comments

Comments
 (0)