Skip to content

Commit de9cf25

Browse files
fix csv logger bug (#21332)
* fix csv logger bug * address comment * fix tests * fix
1 parent 6d70f49 commit de9cf25

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

keras/src/callbacks/csv_logger.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(self, filename, separator=",", append=False):
3737
self.writer = None
3838
self.keys = None
3939
self.append_header = True
40+
self.csv_file = None
4041

4142
def on_train_begin(self, logs=None):
4243
if self.append:
@@ -46,7 +47,13 @@ def on_train_begin(self, logs=None):
4647
mode = "a"
4748
else:
4849
mode = "w"
50+
# ensure csv_file is None or closed before reassigning
51+
if self.csv_file and not self.csv_file.closed:
52+
self.csv_file.close()
4953
self.csv_file = file_utils.File(self.filename, mode)
54+
# Reset writer and keys
55+
self.writer = None
56+
self.keys = None
5057

5158
def on_epoch_end(self, epoch, logs=None):
5259
logs = logs or {}
@@ -65,22 +72,21 @@ def handle_value(k):
6572

6673
if self.keys is None:
6774
self.keys = sorted(logs.keys())
68-
# When validation_freq > 1, `val_` keys are not in first epoch logs
69-
# Add the `val_` keys so that its part of the fieldnames of writer.
75+
7076
val_keys_found = False
7177
for key in self.keys:
7278
if key.startswith("val_"):
7379
val_keys_found = True
7480
break
75-
if not val_keys_found:
81+
if not val_keys_found and self.keys:
7682
self.keys.extend(["val_" + k for k in self.keys])
7783

7884
if not self.writer:
7985

8086
class CustomDialect(csv.excel):
8187
delimiter = self.sep
8288

83-
fieldnames = ["epoch"] + self.keys
89+
fieldnames = ["epoch"] + (self.keys or [])
8490

8591
self.writer = csv.DictWriter(
8692
self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
@@ -96,5 +102,6 @@ class CustomDialect(csv.excel):
96102
self.csv_file.flush()
97103

98104
def on_train_end(self, logs=None):
99-
self.csv_file.close()
105+
if self.csv_file and not self.csv_file.closed:
106+
self.csv_file.close()
100107
self.writer = None

0 commit comments

Comments
 (0)