Skip to content

Commit 27e2d7d

Browse files
Do not use iterations done in json log
1 parent 047a3b3 commit 27e2d7d

File tree

3 files changed

Lines changed: 78 additions & 52 deletions

3 files changed

Lines changed: 78 additions & 52 deletions

blocks/log/json.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,46 +65,55 @@ class JSONLinesLog(TrainingLogBase):
6565
def __init__(self, filename='log.jsonl.gz', maxlen=101, formatter=None,
6666
**kwargs):
6767
self.status = {}
68-
TrainingLogBase.__init__(self)
68+
super(JSONLinesLog, self).__init__()
6969
if os.path.isfile(filename):
7070
os.remove(filename)
7171
self.logger = PicklableLogger(
7272
filename=filename, maxlen=maxlen, formatter=formatter, **kwargs)
7373
self.local_cache = deque()
7474

7575
def flush(self, iterations_done):
76+
if iterations_done < 0:
77+
raise ValueError
7678
if len(self.local_cache) > 0:
7779
self.logger.log({'iterations_done': iterations_done,
7880
'reports': self.local_cache.popleft()})
7981

8082
def __getitem__(self, time):
8183
self._check_time(time)
82-
iterations_done = self.status.get('iterations_done', -1)
84+
logger_len = self.inner_logger_len()
85+
total_length = logger_len + len(self.local_cache)
8386

8487
# Flush local cache
8588
while len(self.local_cache) > 1:
86-
self.flush(iterations_done - len(self.local_cache) + 1)
89+
self.flush(total_length - len(self.local_cache))
90+
logger_len = self.inner_logger_len()
8791

88-
total_length = len(self.logger) + len(self.local_cache)
8992
if time >= total_length:
9093
# Need to create new item in local cache
9194
self.local_cache.extend(
9295
[{} for _ in range(time - total_length + 1)])
93-
last_logged_element = len(self.logger)
94-
if time < last_logged_element:
96+
if time < logger_len:
9597
try:
96-
assert self.logger[time]['iterations_done'] == time
98+
if not self.logger[time]['iterations_done'] == time:
99+
raise ValueError('iterations done')
97100
return self.logger[time]['reports']
98101
except IndexError:
99102
raise ValueError(
100103
'cannot get past log entries for JSON log, max log length '
101104
'in memory is: {}'.format(
102105
self.logger.logger_kwargs['maxlen']))
103-
if time >= last_logged_element:
104-
return self.local_cache[time - last_logged_element]
106+
if time >= logger_len:
107+
return self.local_cache[time - logger_len]
108+
109+
def inner_logger_len(self):
110+
try:
111+
return len(self.logger)
112+
except AttributeError:
113+
return 0
105114

106115
def __len__(self):
107-
return len(self.logger) + len(self.local_cache)
116+
return self.inner_logger_len() + len(self.local_cache)
108117

109118
def __setitem__(self, time, value):
110119
raise ValueError('cannot manually change JSON Lines log')
@@ -115,3 +124,6 @@ def __enter__(self):
115124
def __exit__(self, exc_type, exc_val, exc_tb):
116125
self.flush(self.status.get('iterations_done', -1))
117126
self.logger.close()
127+
128+
def __iter__(self):
129+
return iter([self[i] for i in range(len(self))])

tests/extensions/test_training.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -90,33 +90,34 @@ def test_track_the_best():
9090
extension = TrackTheBest("cost")
9191
extension.main_loop = main_loop
9292

93-
main_loop.status['epochs_done'] += 1
94-
main_loop.status['iterations_done'] += 10
95-
main_loop.log.current_row['cost'] = 5
96-
extension.dispatch('after_epoch')
97-
assert main_loop.status['best_cost'] == 5
98-
assert main_loop.log.current_row['cost_best_so_far']
99-
100-
main_loop.status['epochs_done'] += 1
101-
main_loop.status['iterations_done'] += 10
102-
main_loop.log.current_row['cost'] = 6
103-
extension.dispatch('after_epoch')
104-
assert main_loop.status['best_cost'] == 5
105-
assert main_loop.log.current_row.get('cost_best_so_far', None) is None
106-
107-
main_loop.status['epochs_done'] += 1
108-
main_loop.status['iterations_done'] += 10
109-
main_loop.log.current_row['cost'] = 5
110-
extension.dispatch('after_epoch')
111-
assert main_loop.status['best_cost'] == 5
112-
assert main_loop.log.current_row.get('cost_best_so_far', None) is None
113-
114-
main_loop.status['epochs_done'] += 1
115-
main_loop.status['iterations_done'] += 10
116-
main_loop.log.current_row['cost'] = 4
117-
extension.dispatch('after_epoch')
118-
assert main_loop.status['best_cost'] == 4
119-
assert main_loop.log.current_row['cost_best_so_far']
93+
with main_loop.log:
94+
main_loop.status['epochs_done'] += 1
95+
main_loop.status['iterations_done'] += 10
96+
main_loop.log.current_row['cost'] = 5
97+
extension.dispatch('after_epoch')
98+
assert main_loop.status['best_cost'] == 5
99+
assert main_loop.log.current_row['cost_best_so_far']
100+
101+
main_loop.status['epochs_done'] += 1
102+
main_loop.status['iterations_done'] += 10
103+
main_loop.log.current_row['cost'] = 6
104+
extension.dispatch('after_epoch')
105+
assert main_loop.status['best_cost'] == 5
106+
assert main_loop.log.current_row.get('cost_best_so_far', None) is None
107+
108+
main_loop.status['epochs_done'] += 1
109+
main_loop.status['iterations_done'] += 10
110+
main_loop.log.current_row['cost'] = 5
111+
extension.dispatch('after_epoch')
112+
assert main_loop.status['best_cost'] == 5
113+
assert main_loop.log.current_row.get('cost_best_so_far', None) is None
114+
115+
main_loop.status['epochs_done'] += 1
116+
main_loop.status['iterations_done'] += 10
117+
main_loop.log.current_row['cost'] = 4
118+
extension.dispatch('after_epoch')
119+
assert main_loop.status['best_cost'] == 4
120+
assert main_loop.log.current_row['cost_best_so_far']
120121

121122

122123
class WriteCostExtension(TrainingExtension):

tests/test_log.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,43 @@
33

44
from numpy.testing import assert_raises
55

6-
from blocks.log import TrainingLog
6+
from blocks.log import TrainingLog, JSONLinesLog
77
from blocks.serialization import load, dump
88

99

10-
def test_training_log():
11-
log = TrainingLog()
10+
def run_log(log):
11+
with log:
12+
# test basic writing capabilities
13+
log[0]['field'] = 45
14+
assert log[0]['field'] == 45
15+
assert log[1] == {}
16+
assert log.current_row['field'] == 45
17+
log.status['iterations_done'] += 1
18+
assert log.status['iterations_done'] == 1
19+
assert log.previous_row['field'] == 45
20+
21+
assert_raises(ValueError, getitem, log, -1)
22+
23+
# test iteration
24+
assert len(list(log)) == 2
25+
1226

13-
# test basic writing capabilities
14-
log[0]['field'] = 45
15-
assert log[0]['field'] == 45
16-
assert log[1] == {}
17-
assert log.current_row['field'] == 45
18-
log.status['iterations_done'] += 1
19-
assert log.status['iterations_done'] == 1
20-
assert log.previous_row['field'] == 45
27+
def test_json_lines_log():
28+
log = JSONLinesLog(maxlen=2)
29+
run_log(log)
2130

22-
assert_raises(ValueError, getitem, log, -1)
2331

24-
# test iteration
25-
assert len(list(log)) == 2
32+
def test_training_log():
33+
log = TrainingLog()
34+
run_log(log)
2635

2736

2837
def test_pickle_log():
29-
log1 = TrainingLog()
38+
log = TrainingLog()
39+
pickle_log(log)
40+
41+
42+
def pickle_log(log1):
3043
with open('log1.tar', 'wb') as f:
3144
dump(log1, f)
3245
with open('log1.tar', 'rb') as f:

0 commit comments

Comments (0)