diff --git a/src/lightning/fabric/loggers/csv_logs.py b/src/lightning/fabric/loggers/csv_logs.py
index dd7dfc63671f0..efa0221158344 100644
--- a/src/lightning/fabric/loggers/csv_logs.py
+++ b/src/lightning/fabric/loggers/csv_logs.py
@@ -18,6 +18,7 @@
 from argparse import Namespace
 from typing import Any, Optional, Union
 
+from fsspec.implementations import local
 from torch import Tensor
 from typing_extensions import override
 
@@ -207,6 +208,8 @@ def __init__(self, log_dir: str) -> None:
         self.log_dir = log_dir
         self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)
 
+        self._is_local_fs = isinstance(self._fs, local.LocalFileSystem)
+
         self._check_log_dir_exists()
         self._fs.makedirs(self.log_dir, exist_ok=True)
 
@@ -230,37 +233,52 @@ def save(self) -> None:
         if not self.metrics:
             return
 
+        # Update column list with any new metrics keys
         new_keys = self._record_new_keys()
-        file_exists = self._fs.isfile(self.metrics_file_path)
-
-        if new_keys and file_exists:
-            # we need to re-write the file if the keys (header) change
-            self._rewrite_with_new_header(self.metrics_keys)
 
-        with self._fs.open(self.metrics_file_path, mode=("a" if file_exists else "w"), newline="") as file:
-            writer = csv.DictWriter(file, fieldnames=self.metrics_keys)
-            if not file_exists:
-                # only write the header if we're writing a fresh file
-                writer.writeheader()
-            writer.writerows(self.metrics)
+        file_exists = self._fs.isfile(self.metrics_file_path)
 
-        self.metrics = []  # reset
+        # Decision logic: when can we safely append?
+        # 1. Must be local filesystem (remote FS don't support append)
+        # 2. File must already exist
+        # 3. No new columns (otherwise CSV header would be wrong)
+        can_append = self._is_local_fs and file_exists and not new_keys
+
+        if can_append:
+            # Safe to append: local FS + existing file + same columns
+            self._write_metrics(self.metrics, mode="a", write_header=False)
+        else:
+            # Need to rewrite: new file OR remote FS OR new columns
+            all_metrics = self.metrics
+            if file_exists:
+                # Include existing data when rewriting
+                all_metrics = self._read_existing_metrics() + self.metrics
+            self._write_metrics(all_metrics, mode="w", write_header=True)
+
+        self.metrics = []
 
     def _record_new_keys(self) -> set[str]:
-        """Records new keys that have not been logged before."""
+        """Identifies and records any new metric keys that have not been previously logged."""
         current_keys = set().union(*self.metrics)
         new_keys = current_keys - set(self.metrics_keys)
         self.metrics_keys.extend(new_keys)
         self.metrics_keys.sort()
         return new_keys
 
-    def _rewrite_with_new_header(self, fieldnames: list[str]) -> None:
-        with self._fs.open(self.metrics_file_path, "r", newline="") as file:
-            metrics = list(csv.DictReader(file))
-
-        with self._fs.open(self.metrics_file_path, "w", newline="") as file:
-            writer = csv.DictWriter(file, fieldnames=fieldnames)
-            writer.writeheader()
+    def _read_existing_metrics(self) -> list[dict[str, Any]]:
+        """Read all existing metrics from the CSV file."""
+        try:
+            with self._fs.open(self.metrics_file_path, "r", newline="") as file:
+                return list(csv.DictReader(file))
+        except (FileNotFoundError, OSError):
+            return []
+
+    def _write_metrics(self, metrics: list[dict[str, Any]], mode: str, write_header: bool) -> None:
+        """Write metrics to CSV file with the specified mode and header option."""
+        with self._fs.open(self.metrics_file_path, mode=mode, newline="") as file:
+            writer = csv.DictWriter(file, fieldnames=self.metrics_keys)
+            if write_header:
+                writer.writeheader()
             writer.writerows(metrics)
 
     def _check_log_dir_exists(self) -> None:
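The save() logic above reduces to one rule: append only when the filesystem is local, the file already exists, and no new columns appeared; otherwise merge the old rows back in and rewrite the whole file. A minimal standalone sketch of that rule, assuming an fsspec filesystem and a hypothetical save_rows helper (the helper name and signature are not part of the patch):

    import csv

    from fsspec.implementations.local import LocalFileSystem


    def save_rows(fs, path, rows, fieldnames, new_keys):
        # Appending is only safe on a local filesystem, to a file that
        # already exists, and only while the column set is unchanged.
        file_exists = fs.isfile(path)
        can_append = isinstance(fs, LocalFileSystem) and file_exists and not new_keys
        if not can_append and file_exists:
            # Rewriting: fold previously written rows back in so no data is lost.
            with fs.open(path, "r", newline="") as f:
                rows = list(csv.DictReader(f)) + rows
        with fs.open(path, mode=("a" if can_append else "w"), newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            if not can_append:
                writer.writeheader()
            writer.writerows(rows)

The tests below exercise exactly these branches: append on a local filesystem with stable columns, and forced rewrite on remote filesystems or when the header changes.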
diff --git a/tests/tests_fabric/loggers/test_csv.py b/tests/tests_fabric/loggers/test_csv.py
index 08ed3990c2435..832d0257d71f2 100644
--- a/tests/tests_fabric/loggers/test_csv.py
+++ b/tests/tests_fabric/loggers/test_csv.py
@@ -178,14 +178,19 @@ def test_append_columns(tmp_path):
     # new key appears
     logger.log_metrics({"a": 1, "b": 2, "c": 3})
     with open(logger.experiment.metrics_file_path) as file:
-        header = file.readline().strip()
-        assert set(header.split(",")) == {"step", "a", "b", "c"}
+        lines = file.readlines()
+        header = lines[0].strip()
+        assert header.split(",") == ["a", "b", "c", "step"]
+        assert len(lines) == 3  # header + 2 data rows
 
     # key disappears
     logger.log_metrics({"a": 1, "c": 3})
+    logger.save()
     with open(logger.experiment.metrics_file_path) as file:
-        header = file.readline().strip()
-        assert set(header.split(",")) == {"step", "a", "b", "c"}
+        lines = file.readlines()
+        header = lines[0].strip()
+        assert header.split(",") == ["a", "b", "c", "step"]
+        assert len(lines) == 4  # header + 3 data rows
 
 
 @mock.patch(
@@ -193,21 +198,27 @@ def test_append_columns(tmp_path):
     "lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
 )
 def test_rewrite_with_new_header(_, tmp_path):
-    # write a csv file manually
-    with open(tmp_path / "metrics.csv", "w") as file:
-        file.write("step,metric1,metric2\n")
-        file.write("0,1,22\n")
+    """Test that existing files get rewritten correctly when new columns are added."""
+    # write a csv file manually to simulate existing data
+    csv_path = tmp_path / "metrics.csv"
+    with open(csv_path, "w") as file:
+        file.write("a,b,step\n")
+        file.write("1,2,0\n")
 
     writer = _ExperimentWriter(log_dir=str(tmp_path))
-    new_columns = ["step", "metric1", "metric2", "metric3"]
-    writer._rewrite_with_new_header(new_columns)
 
-    # the rewritten file should have the new columns
-    with open(tmp_path / "metrics.csv") as file:
-        header = file.readline().strip().split(",")
-        assert header == new_columns
-        logs = file.readline().strip().split(",")
-        assert logs == ["0", "1", "22", ""]
+    # Add metrics with a new column
+    writer.log_metrics({"a": 2, "b": 3, "c": 4}, step=1)
+    writer.save()
+    # The rewritten file should have the new columns and preserve old data
+    with open(csv_path) as file:
+        lines = file.readlines()
+        assert len(lines) == 3  # header + 2 data rows
+        header = lines[0].strip()
+        assert header.split(",") == ["a", "b", "c", "step"]
+        # verify old data is preserved
+        assert lines[1].strip().split(",") == ["1", "2", "", "0"]  # old row with empty new column
+        assert lines[2].strip().split(",") == ["2", "3", "4", "1"]
 
 
 def test_log_metrics_column_order_sorted(tmp_path):
@@ -221,8 +232,66 @@ def test_log_metrics_column_order_sorted(tmp_path):
     logger.log_metrics({"d": 0.5})
     logger.save()
 
-    path_csv = os.path.join(logger.log_dir, _ExperimentWriter.NAME_METRICS_FILE)
-    with open(path_csv) as fp:
+    with open(logger.experiment.metrics_file_path) as fp:
         lines = fp.readlines()
 
     assert lines[0].strip() == "a,b,c,d,step"
+
+
+@mock.patch("lightning.fabric.loggers.csv_logs.get_filesystem")
+@mock.patch("lightning.fabric.loggers.csv_logs._ExperimentWriter._read_existing_metrics")
+def test_remote_filesystem_uses_write_mode(mock_read_existing, mock_get_fs, tmp_path):
+    """Test that remote filesystems use write mode."""
+    mock_fs = MagicMock()
+    mock_fs.isfile.return_value = False  # File doesn't exist
+    mock_fs.makedirs = MagicMock()
+    mock_get_fs.return_value = mock_fs
+
+    logger = CSVLogger(tmp_path)
+    assert not logger.experiment._is_local_fs
+
+    logger.log_metrics({"a": 0.3}, step=1)
+    logger.save()
+
+    # Verify _read_existing_metrics was NOT called (file doesn't exist)
+    mock_read_existing.assert_not_called()
+
+    # Verify write mode was used (remote FS should never use append)
+    mock_fs.open.assert_called()
+    call_args = mock_fs.open.call_args_list[-1]  # Get the last call
+
+    # Extract the mode parameter specifically
+    args, kwargs = call_args
+    mode = kwargs.get("mode", "r")  # Default to 'r' if mode not specified
+    assert mode == "w", f"Expected write mode 'w', but got mode: '{mode}'"
+
+
+@mock.patch("lightning.fabric.loggers.csv_logs.get_filesystem")
+@mock.patch("lightning.fabric.loggers.csv_logs._ExperimentWriter._read_existing_metrics")
+def test_remote_filesystem_preserves_existing_data(mock_read_existing, mock_get_fs, tmp_path):
+    """Test that remote filesystem reads existing data and preserves it when rewriting."""
+    # Mock remote filesystem with existing file
+    mock_fs = MagicMock()
+    mock_fs.isfile.return_value = True
+    mock_fs.makedirs = MagicMock()
+    mock_get_fs.return_value = mock_fs
+
+    # Mock existing data
+    mock_read_existing.return_value = [{"a": 0.1, "step": 0}, {"a": 0.2, "step": 1}]
+
+    logger = CSVLogger(tmp_path)
+    assert not logger.experiment._is_local_fs
+
+    # Add new metrics - should read existing and combine
+    logger.log_metrics({"a": 0.3}, step=2)
+    logger.save()
+
+    # Verify that _read_existing_metrics was called (should read existing data)
+    mock_read_existing.assert_called_once()
+
+    # Verify write mode was used
+    mock_fs.open.assert_called()
+    last_call = mock_fs.open.call_args_list[-1]
+    args, kwargs = last_call
+    mode = kwargs.get("mode", "r")
+    assert mode == "w", f"Expected write mode 'w', but got mode: '{mode}'"
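With the patch applied, a single session exercises all three write paths: fresh write, in-place append, and a full rewrite once a new column appears. A quick local sanity check, assuming the patched behavior (the "logs" directory and "demo" name are arbitrary, not from the patch):

    import csv

    from lightning.fabric.loggers import CSVLogger

    logger = CSVLogger("logs", name="demo")
    logger.log_metrics({"a": 0.1}, step=0)
    logger.save()  # fresh file: header is written
    logger.log_metrics({"a": 0.2}, step=1)
    logger.save()  # local FS, same columns: rows are appended in place
    logger.log_metrics({"a": 0.3, "b": 1.0}, step=2)
    logger.save()  # new column "b": file is rewritten with the merged header

    with open(logger.experiment.metrics_file_path, newline="") as f:
        rows = list(csv.DictReader(f))
    print(rows)  # three rows; the first two carry "" under the late column "b"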