@@ -37,6 +37,7 @@ def __init__(self, filename, separator=",", append=False):
37
37
self .writer = None
38
38
self .keys = None
39
39
self .append_header = True
40
+ self .csv_file = None
40
41
41
42
def on_train_begin (self , logs = None ):
42
43
if self .append :
@@ -46,7 +47,13 @@ def on_train_begin(self, logs=None):
46
47
mode = "a"
47
48
else :
48
49
mode = "w"
50
+ # ensure csv_file is None or closed before reassigning
51
+ if self .csv_file and not self .csv_file .closed :
52
+ self .csv_file .close ()
49
53
self .csv_file = file_utils .File (self .filename , mode )
54
+ # Reset writer and keys
55
+ self .writer = None
56
+ self .keys = None
50
57
51
58
def on_epoch_end (self , epoch , logs = None ):
52
59
logs = logs or {}
@@ -65,22 +72,21 @@ def handle_value(k):
65
72
66
73
if self .keys is None :
67
74
self .keys = sorted (logs .keys ())
68
- # When validation_freq > 1, `val_` keys are not in first epoch logs
69
- # Add the `val_` keys so that its part of the fieldnames of writer.
75
+
70
76
val_keys_found = False
71
77
for key in self .keys :
72
78
if key .startswith ("val_" ):
73
79
val_keys_found = True
74
80
break
75
- if not val_keys_found :
81
+ if not val_keys_found and self . keys :
76
82
self .keys .extend (["val_" + k for k in self .keys ])
77
83
78
84
if not self .writer :
79
85
80
86
class CustomDialect (csv .excel ):
81
87
delimiter = self .sep
82
88
83
- fieldnames = ["epoch" ] + self .keys
89
+ fieldnames = ["epoch" ] + ( self .keys or [])
84
90
85
91
self .writer = csv .DictWriter (
86
92
self .csv_file , fieldnames = fieldnames , dialect = CustomDialect
@@ -96,5 +102,6 @@ class CustomDialect(csv.excel):
96
102
self .csv_file .flush ()
97
103
98
104
def on_train_end (self , logs = None ):
99
- self .csv_file .close ()
105
+ if self .csv_file and not self .csv_file .closed :
106
+ self .csv_file .close ()
100
107
self .writer = None
0 commit comments