Skip to content

Commit d85dbe4

Browse files
committed
refactor: enhance CSVLogger's metric writing logic for local and remote filesystem support
1 parent 011f0c3 commit d85dbe4

File tree

1 file changed

+38
-21
lines changed

1 file changed

+38
-21
lines changed

src/lightning/fabric/loggers/csv_logs.py

Lines changed: 38 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -208,6 +208,8 @@ def __init__(self, log_dir: str) -> None:
208208
self.log_dir = log_dir
209209
self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)
210210

211+
self._is_local_fs = isinstance(self._fs, local.LocalFileSystem)
212+
211213
self._check_log_dir_exists()
212214
self._fs.makedirs(self.log_dir, exist_ok=True)
213215

@@ -231,38 +233,53 @@ def save(self) -> None:
231233
if not self.metrics:
232234
return
233235

236+
# Update column list with any new metrics keys
234237
new_keys = self._record_new_keys()
235-
file_exists = self._fs.isfile(self.metrics_file_path)
236-
rewrite_file = not isinstance(self._fs, local.LocalFileSystem) or new_keys
237-
238-
if rewrite_file and file_exists:
239-
self._append_recorded_metrics()
240238

241-
with self._fs.open(self.metrics_file_path, mode=("a" if not rewrite_file else "w"), newline="") as file:
242-
writer = csv.DictWriter(file, fieldnames=self.metrics_keys)
243-
if rewrite_file:
244-
writer.writeheader()
245-
writer.writerows(self.metrics)
239+
file_exists = self._fs.isfile(self.metrics_file_path)
246240

247-
self.metrics = [] # reset
241+
# Decision logic: when can we safely append?
242+
# 1. Must be local filesystem (remote FS don't support append)
243+
# 2. File must already exist
244+
# 3. No new columns (otherwise CSV header would be wrong)
245+
can_append = self._is_local_fs and file_exists and not new_keys
246+
247+
if can_append:
248+
# Safe to append: local FS + existing file + same columns
249+
self._write_metrics(self.metrics, mode="a", write_header=False)
250+
else:
251+
# Need to rewrite: new file OR remote FS OR new columns
252+
all_metrics = self.metrics
253+
if file_exists:
254+
# Include existing data when rewriting
255+
all_metrics = self._read_existing_metrics() + self.metrics
256+
self._write_metrics(all_metrics, mode="w", write_header=True)
257+
258+
self.metrics = []
248259

249260
def _record_new_keys(self) -> set[str]:
250-
"""Records new keys that have not been logged before."""
261+
"""Identifies and records any new metric keys that have not been previously logged."""
251262
current_keys = set().union(*self.metrics)
252263
new_keys = current_keys - set(self.metrics_keys)
253264
self.metrics_keys.extend(new_keys)
254265
self.metrics_keys.sort()
255266
return new_keys
256267

257-
def _append_recorded_metrics(self) -> None:
258-
"""Appends the previous recorded metrics to the current ``self.metrics``."""
259-
metrics = self._fetch_recorded_metrics()
260-
self.metrics = metrics + self.metrics
261-
262-
def _fetch_recorded_metrics(self) -> list[dict[str, Any]]:
263-
"""Fetches the previous recorded metrics."""
264-
with self._fs.open(self.metrics_file_path, "r", newline="") as file:
265-
return list(csv.DictReader(file))
268+
def _read_existing_metrics(self) -> list[dict[str, Any]]:
269+
"""Read all existing metrics from the CSV file."""
270+
try:
271+
with self._fs.open(self.metrics_file_path, "r", newline="") as file:
272+
return list(csv.DictReader(file))
273+
except (FileNotFoundError, OSError):
274+
return []
275+
276+
def _write_metrics(self, metrics: list[dict[str, Any]], mode: str, write_header: bool) -> None:
277+
"""Write metrics to CSV file with the specified mode and header option."""
278+
with self._fs.open(self.metrics_file_path, mode=mode, newline="") as file:
279+
writer = csv.DictWriter(file, fieldnames=self.metrics_keys)
280+
if write_header:
281+
writer.writeheader()
282+
writer.writerows(metrics)
266283

267284
def _check_log_dir_exists(self) -> None:
268285
if self._fs.exists(self.log_dir) and self._fs.listdir(self.log_dir):

0 commit comments

Comments
 (0)