Skip to content

Commit b8aaa54

Browse files
authored
log_reader: tighten parsing (#438)
The log format has a few places where we insert `\n` for human readability. They should be checked they are indeed just that character. Similarly, checking that the tensor data received matches in size what was expected. Refactored a bit the test utility for constructing examples.
1 parent a8559c1 commit b8aaa54

File tree

3 files changed

+172
-93
lines changed

3 files changed

+172
-93
lines changed

compiler_opt/rl/env_test.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,23 @@ def kill(self):
8989
with io.FileIO(fname + '.out', 'wb+') as f_out:
9090
with io.FileIO(fname + '.in', 'rb+') as f_in:
9191
del f_in
92+
writer = log_reader_test.LogTestExampleBuilder(opened_file=f_out)
9293
# Write the header describing the features/rewards
93-
f_out.write(
94-
log_reader_test.json_to_bytes({
95-
'features': [{
96-
'name': 'times_called',
97-
'port': 0,
98-
'shape': [1],
99-
'type': 'int64_t',
100-
},],
101-
'score': {
102-
'name': 'reward',
103-
'port': 0,
104-
'shape': [1],
105-
'type': 'float',
106-
},
107-
}))
108-
log_reader_test.write_nl(f_out)
94+
writer.write_header({
95+
'features': [{
96+
'name': 'times_called',
97+
'port': 0,
98+
'shape': [1],
99+
'type': 'int64_t',
100+
},],
101+
'score': {
102+
'name': 'reward',
103+
'port': 0,
104+
'shape': [1],
105+
'type': 'float',
106+
},
107+
})
108+
writer.write_newline()
109109

110110
class MockInteractiveProcess(MockProcess):
111111
"""Mock clang interactive process that writes the log."""
@@ -120,14 +120,15 @@ def poll(self):
120120
if self._counter >= _NUM_STEPS:
121121
f_out.close()
122122
return None
123-
log_reader_test.write_context_marker(f_out,
124-
f'context_{self._counter}')
125-
log_reader_test.write_observation_marker(f_out, 0)
126-
log_reader_test.write_buff(f_out, [self._counter], ctypes.c_int64)
127-
log_reader_test.write_nl(f_out)
128-
log_reader_test.write_outcome_marker(f_out, 0)
129-
log_reader_test.write_buff(f_out, [3.14], ctypes.c_float)
130-
log_reader_test.write_nl(f_out)
123+
example_writer = log_reader_test.LogTestExampleBuilder(
124+
opened_file=f_out)
125+
example_writer.write_context_marker(f'context_{self._counter}')
126+
example_writer.write_observation_marker(0)
127+
example_writer.write_buff([self._counter], ctypes.c_int64)
128+
example_writer.write_newline()
129+
example_writer.write_outcome_marker(0)
130+
example_writer.write_buff([3.14], ctypes.c_float)
131+
example_writer.write_newline()
131132
self._counter += 1
132133
return None
133134

compiler_opt/rl/log_reader.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ class _Header:
144144
def _read_tensor(fs: BinaryIO, ts: tf.TensorSpec) -> LogReaderTensorValue:
145145
size = math.prod(ts.shape) * ctypes.sizeof(_dtype_to_ctype[ts.dtype])
146146
data = fs.read(size)
147+
if len(data) != size:
148+
raise IOError(
149+
f'Expected to read a total of {size} bytes for tensors, got {len(data)}'
150+
)
147151
return LogReaderTensorValue(ts, data)
148152

149153

@@ -175,20 +179,28 @@ def _enumerate_log_from_stream(
175179
tensor_specs = header.features
176180
score_spec = header.score
177181
context = None
182+
183+
def expect_newline():
184+
expected = f.readline().decode('utf-8')
185+
if '\n' != expected:
186+
raise IOError(f'Expected newline in log stream, got {expected}')
187+
178188
while event_str := f.readline():
179189
event = json.loads(event_str)
180190
if 'context' in event:
181191
context = event['context']
182192
continue
183193
observation_id = int(event['observation'])
184194
features = [_read_tensor(f, ts) for ts in tensor_specs]
185-
f.readline()
195+
expect_newline()
186196
score = None
187197
if score_spec is not None:
188198
score_header = json.loads(f.readline())
189-
assert int(score_header['outcome']) == observation_id
199+
if int(score_header['outcome']) != observation_id:
200+
raise IOError(f'Expected observation ID {observation_id} \
201+
got {score_header["outcome"]}')
190202
score = _read_tensor(f, score_spec)
191-
f.readline()
203+
expect_newline()
192204
yield ObservationRecord(
193205
context=context,
194206
observation_id=observation_id,

compiler_opt/rl/log_reader_test.py

Lines changed: 132 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Tests for compiler_opt.rl.log_reader."""
1616

1717
import ctypes
18+
import enum
1819
import json
1920
from compiler_opt.rl import log_reader
2021

@@ -30,83 +31,113 @@ def json_to_bytes(d) -> bytes:
3031
return json.dumps(d).encode('utf-8')
3132

3233

33-
nl = '\n'.encode('utf-8')
34-
35-
36-
def write_buff(f: BinaryIO, buffer: list, ct):
37-
# we should get the ctypes array to bytes for pytype to be happy.
38-
f.write((ct * len(buffer))(*buffer)) # pytype:disable=wrong-arg-types
39-
40-
41-
def write_context_marker(f: BinaryIO, name: str):
42-
f.write(json_to_bytes({'context': name}))
43-
f.write(nl)
44-
45-
46-
def write_observation_marker(f: BinaryIO, obs_idx: int):
47-
f.write(json_to_bytes({'observation': obs_idx}))
48-
f.write(nl)
49-
50-
51-
def write_nl(f: BinaryIO):
52-
f.write(nl)
53-
54-
55-
def write_outcome_marker(f: BinaryIO, obs_idx: int):
56-
f.write(json_to_bytes({'outcome': obs_idx}))
57-
f.write(nl)
58-
59-
60-
def create_example(fname: str, nr_contexts=1):
34+
class LogTestExampleBuilder:
35+
"""Construct a log."""
36+
37+
newline = b'\n'
38+
error_newline = b'hi there'
39+
40+
class ErrorMarkers(enum.IntEnum):
41+
NONE = 0
42+
AFTER_HEADER = enum.auto()
43+
CTX_MARKER_POS = enum.auto()
44+
OBS_MARKER_POS = enum.auto()
45+
OUTCOME_MARKER_POS = enum.auto()
46+
TENSOR_BUF_POS = enum.auto()
47+
TENSORS_POS = enum.auto()
48+
OUTCOME_POS = enum.auto()
49+
50+
def __init__(
51+
self,
52+
*,
53+
opened_file: BinaryIO,
54+
introduce_error_pos: ErrorMarkers = ErrorMarkers.NONE,
55+
):
56+
self._opened_file = opened_file
57+
self._introduce_error_pos = introduce_error_pos
58+
59+
def write_buff(self, buffer: list, ct):
60+
# we should get the ctypes array to bytes for pytype to be happy.
61+
if self._introduce_error_pos == self.ErrorMarkers.TENSOR_BUF_POS:
62+
buffer = buffer[len(buffer) // 2:]
63+
# pytype:disable=wrong-arg-types
64+
self._opened_file.write((ct * len(buffer))(*buffer))
65+
# pytype:enable=wrong-arg-types
66+
67+
def write_newline(self, position=None):
68+
self._opened_file.write(self.error_newline if position ==
69+
self._introduce_error_pos else self.newline)
70+
71+
def write_context_marker(self, name: str):
72+
self._opened_file.write(json_to_bytes({'context': name}))
73+
self.write_newline(self.ErrorMarkers.CTX_MARKER_POS)
74+
75+
def write_observation_marker(self, obs_idx: int):
76+
self._opened_file.write(json_to_bytes({'observation': obs_idx}))
77+
self.write_newline(self.ErrorMarkers.OBS_MARKER_POS)
78+
79+
def write_outcome_marker(self, obs_idx: int):
80+
self._opened_file.write(json_to_bytes({'outcome': obs_idx}))
81+
self.write_newline(self.ErrorMarkers.OUTCOME_MARKER_POS)
82+
83+
def write_header(self, json_header: dict):
84+
self._opened_file.write(json_to_bytes(json_header))
85+
86+
87+
def create_example(fname: str,
88+
*,
89+
nr_contexts=1,
90+
introduce_errors_pos: LogTestExampleBuilder
91+
.ErrorMarkers = LogTestExampleBuilder.ErrorMarkers.NONE):
6192
t0_val = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
6293
t1_val = [1, 2, 3]
6394
s = [1.2]
6495

6596
with open(fname, 'wb') as f:
66-
f.write(
67-
json_to_bytes({
68-
'features': [{
69-
'name': 'tensor_name2',
70-
'port': 0,
71-
'shape': [2, 3],
72-
'type': 'float',
73-
}, {
74-
'name': 'tensor_name1',
75-
'port': 0,
76-
'shape': [3, 1],
77-
'type': 'int64_t',
78-
}],
79-
'score': {
80-
'name': 'reward',
81-
'port': 0,
82-
'shape': [1],
83-
'type': 'float'
84-
}
85-
}))
86-
write_nl(f)
97+
example_writer = LogTestExampleBuilder(
98+
opened_file=f, introduce_error_pos=introduce_errors_pos)
99+
example_writer.write_header({
100+
'features': [{
101+
'name': 'tensor_name2',
102+
'port': 0,
103+
'shape': [2, 3],
104+
'type': 'float',
105+
}, {
106+
'name': 'tensor_name1',
107+
'port': 0,
108+
'shape': [3, 1],
109+
'type': 'int64_t',
110+
}],
111+
'score': {
112+
'name': 'reward',
113+
'port': 0,
114+
'shape': [1],
115+
'type': 'float'
116+
}
117+
})
118+
example_writer.write_newline(
119+
LogTestExampleBuilder.ErrorMarkers.AFTER_HEADER)
87120
for ctx_id in range(nr_contexts):
88121
t0_val = [v + ctx_id * 10 for v in t0_val]
89122
t1_val = [v + ctx_id * 10 for v in t1_val]
90-
write_context_marker(f, f'context_nr_{ctx_id}')
91-
write_observation_marker(f, 0)
92-
write_buff(f, t0_val, ctypes.c_float)
93-
write_buff(f, t1_val, ctypes.c_int64)
94-
write_nl(f)
95-
write_outcome_marker(f, 0)
96-
write_buff(f, s, ctypes.c_float)
97-
write_nl(f)
98-
123+
example_writer.write_context_marker(f'context_nr_{ctx_id}')
124+
125+
def write_example_obs(obs: int):
126+
example_writer.write_observation_marker(obs)
127+
example_writer.write_buff(t0_val, ctypes.c_float)
128+
example_writer.write_buff(t1_val, ctypes.c_int64)
129+
example_writer.write_newline(
130+
LogTestExampleBuilder.ErrorMarkers.TENSORS_POS)
131+
example_writer.write_outcome_marker(obs)
132+
example_writer.write_buff(s, ctypes.c_float)
133+
example_writer.write_newline(
134+
LogTestExampleBuilder.ErrorMarkers.OUTCOME_POS)
135+
136+
write_example_obs(0)
99137
t0_val = [v + 1 for v in t0_val]
100138
t1_val = [v + 1 for v in t1_val]
101139
s[0] += 1
102-
103-
write_observation_marker(f, 1)
104-
write_buff(f, t0_val, ctypes.c_float)
105-
write_buff(f, t1_val, ctypes.c_int64)
106-
write_nl(f)
107-
write_outcome_marker(f, 1)
108-
write_buff(f, s, ctypes.c_float)
109-
write_nl(f)
140+
write_example_obs(1)
110141

111142

112143
class LogReaderTest(tf.test.TestCase):
@@ -246,6 +277,41 @@ def test_seq_example_conversion(self):
246277
""", tf.train.SequenceExample())
247278
self.assertProtoEquals(expected_ctx_0, seq_examples['context_nr_0'])
248279

280+
def test_errors(self):
281+
logfile = self.create_tempfile()
282+
for error_marker in LogTestExampleBuilder.ErrorMarkers:
283+
if not error_marker:
284+
continue
285+
create_example(logfile, introduce_errors_pos=error_marker)
286+
with self.assertRaises(Exception):
287+
log_reader.read_log_as_sequence_examples(logfile)
288+
289+
def test_truncated_tensors(self):
290+
logfile = self.create_tempfile()
291+
with open(logfile, 'wb') as f:
292+
writer = LogTestExampleBuilder(opened_file=f)
293+
writer.write_header({
294+
'features': [{
295+
'name': 'tensor_name',
296+
'port': 0,
297+
'shape': [2, 3],
298+
'type': 'float',
299+
}],
300+
'score': {
301+
'name': 'reward',
302+
'port': 0,
303+
'shape': [1],
304+
'type': 'float'
305+
}
306+
})
307+
writer.write_newline()
308+
writer.write_context_marker('whatever')
309+
writer.write_observation_marker(0)
310+
writer.write_buff([1], ctypes.c_int16)
311+
312+
with self.assertRaises(Exception):
313+
log_reader.read_log_as_sequence_examples(logfile)
314+
249315

250316
if __name__ == '__main__':
251317
tf.test.main()

0 commit comments

Comments
 (0)