diff --git a/src/lightning/fabric/loggers/csv_logs.py b/src/lightning/fabric/loggers/csv_logs.py index dd7dfc63671f0..1e43a77fb4bea 100644 --- a/src/lightning/fabric/loggers/csv_logs.py +++ b/src/lightning/fabric/loggers/csv_logs.py @@ -45,6 +45,9 @@ class CSVLogger(Logger): overwritten. prefix: A string to put at the beginning of metric keys. flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps). + sub_dir: Sub-directory to group CSV logs. If a ``sub_dir`` argument is passed + then logs are saved in ``/root_dir/name/version/sub_dir/``. Defaults to ``None`` in which case + logs are saved in ``/root_dir/name/version/``. Example:: @@ -65,6 +68,7 @@ def __init__( version: Optional[Union[int, str]] = None, prefix: str = "", flush_logs_every_n_steps: int = 100, + sub_dir: Optional[_PATH] = None, ): super().__init__() root_dir = os.fspath(root_dir) @@ -75,6 +79,7 @@ def __init__( self._fs = get_filesystem(root_dir) self._experiment: Optional[_ExperimentWriter] = None self._flush_logs_every_n_steps = flush_logs_every_n_steps + self._sub_dir = None if sub_dir is None else os.fspath(sub_dir) @property @override @@ -117,7 +122,22 @@ def log_dir(self) -> str: """ # create a pseudo standard path version = self.version if isinstance(self.version, str) else f"version_{self.version}" - return os.path.join(self._root_dir, self.name, version) + log_dir = os.path.join(self.root_dir, self.name, version) + if isinstance(self.sub_dir, str): + log_dir = os.path.join(log_dir, self.sub_dir) + log_dir = os.path.expandvars(log_dir) + log_dir = os.path.expanduser(log_dir) + return log_dir + + @property + def sub_dir(self) -> Optional[str]: + """Gets the sub directory where the CSV experiments are saved. + + Returns: + The local path to the sub directory where the CSV experiments are saved. + + """ + return self._sub_dir @property @rank_zero_experiment diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py index 5ad7353310af4..3ccd9ab00901a 100644 --- a/src/lightning/pytorch/loggers/csv_logs.py +++ b/src/lightning/pytorch/loggers/csv_logs.py @@ -92,6 +92,7 @@ def __init__( version: Optional[Union[int, str]] = None, prefix: str = "", flush_logs_every_n_steps: int = 100, + sub_dir: Optional[_PATH] = None, ): super().__init__( root_dir=save_dir, @@ -99,6 +100,7 @@ def __init__( version=version, prefix=prefix, flush_logs_every_n_steps=flush_logs_every_n_steps, + sub_dir=sub_dir, ) self._save_dir = os.fspath(save_dir) @@ -124,7 +126,12 @@ def log_dir(self) -> str: """ # create a pseudo standard path version = self.version if isinstance(self.version, str) else f"version_{self.version}" - return os.path.join(self.root_dir, version) + log_dir = os.path.join(self.root_dir, version) + if isinstance(self.sub_dir, str): + log_dir = os.path.join(log_dir, self.sub_dir) + log_dir = os.path.expandvars(log_dir) + log_dir = os.path.expanduser(log_dir) + return log_dir @property @override diff --git a/tests/tests_fabric/loggers/test_csv.py b/tests/tests_fabric/loggers/test_csv.py index 08ed3990c2435..04169c20e2e17 100644 --- a/tests/tests_fabric/loggers/test_csv.py +++ b/tests/tests_fabric/loggers/test_csv.py @@ -91,6 +91,36 @@ def test_no_name(tmp_path, name): assert os.listdir(tmp_path / "version_0") +def test_csv_log_sub_dir(tmp_path): + # no sub_dir specified + root_dir = tmp_path / "logs" + logger = CSVLogger(root_dir, name="name", version="version") + assert logger.log_dir == os.path.join(root_dir, "name", "version") + + # sub_dir specified + logger = CSVLogger(root_dir, name="name", version="version", sub_dir="sub_dir") + assert logger.log_dir == os.path.join(root_dir, "name", "version", "sub_dir") + + +def test_csv_expand_home(): + """Test that the home dir (`~`) gets expanded properly.""" + root_dir = "~/tmp" + explicit_root_dir = os.path.expanduser(root_dir) + logger = CSVLogger(root_dir, name="name", version="version", sub_dir="sub_dir") + assert logger.root_dir == root_dir + assert logger.log_dir == os.path.join(explicit_root_dir, "name", "version", "sub_dir") + + +@mock.patch.dict(os.environ, {"TEST_ENV_DIR": "some_directory"}) +def test_tensorboard_expand_env_vars(): + """Test that the env vars in path names (`$`) get handled properly.""" + test_env_dir = os.environ["TEST_ENV_DIR"] + root_dir = "$TEST_ENV_DIR/tmp" + explicit_root_dir = f"{test_env_dir}/tmp" + logger = CSVLogger(root_dir, name="name", version="version", sub_dir="sub_dir") + assert logger.log_dir == os.path.join(explicit_root_dir, "name", "version", "sub_dir") + + @pytest.mark.parametrize("step_idx", [10, None]) def test_log_metrics(tmp_path, step_idx): logger = CSVLogger(tmp_path) diff --git a/tests/tests_pytorch/loggers/test_csv.py b/tests/tests_pytorch/loggers/test_csv.py index 3b1e4dc91e391..a29ffe5cca4ca 100644 --- a/tests/tests_pytorch/loggers/test_csv.py +++ b/tests/tests_pytorch/loggers/test_csv.py @@ -89,6 +89,37 @@ def test_no_name(tmp_path, name): assert os.listdir(tmp_path / "version_0") +def test_csv_log_sub_dir(tmp_path): + # no sub_dir specified + root_dir = tmp_path / "logs" + logger = CSVLogger(root_dir, name="name", version="version") + assert logger.log_dir == os.path.join(root_dir, "name", "version") + + # sub_dir specified + logger = CSVLogger(root_dir, name="name", version="version", sub_dir="sub_dir") + assert logger.log_dir == os.path.join(root_dir, "name", "version", "sub_dir") + + +def test_csv_expand_home(): + """Test that the home dir (`~`) gets expanded properly.""" + save_dir = os.path.join("~", "tmp") + root_dir = os.path.join(save_dir, "name") + explicit_root_dir = os.path.expanduser(save_dir) + logger = CSVLogger(save_dir, name="name", version="version", sub_dir="sub_dir") + assert logger.root_dir == root_dir + assert logger.log_dir == os.path.join(explicit_root_dir, "name", "version", "sub_dir") + + +@mock.patch.dict(os.environ, {"TEST_ENV_DIR": "some_directory"}) +def test_tensorboard_expand_env_vars(): + """Test that the env vars in path names (`$`) get handled properly.""" + test_env_dir = os.environ["TEST_ENV_DIR"] + root_dir = "$TEST_ENV_DIR/tmp" + explicit_root_dir = f"{test_env_dir}/tmp" + logger = CSVLogger(root_dir, name="name", version="version", sub_dir="sub_dir") + assert logger.log_dir == os.path.join(explicit_root_dir, "name", "version", "sub_dir") + + @pytest.mark.parametrize("step_idx", [10, None]) def test_log_metrics(tmp_path, step_idx): logger = CSVLogger(tmp_path)