15
15
"""Tests for compiler_opt.rl.log_reader."""
16
16
17
17
import ctypes
18
+ import enum
18
19
import json
19
20
from compiler_opt .rl import log_reader
20
21
@@ -30,83 +31,113 @@ def json_to_bytes(d) -> bytes:
30
31
return json .dumps (d ).encode ('utf-8' )
31
32
32
33
33
- nl = '\n ' .encode ('utf-8' )
34
-
35
-
36
- def write_buff (f : BinaryIO , buffer : list , ct ):
37
- # we should get the ctypes array to bytes for pytype to be happy.
38
- f .write ((ct * len (buffer ))(* buffer )) # pytype:disable=wrong-arg-types
39
-
40
-
41
- def write_context_marker (f : BinaryIO , name : str ):
42
- f .write (json_to_bytes ({'context' : name }))
43
- f .write (nl )
44
-
45
-
46
- def write_observation_marker (f : BinaryIO , obs_idx : int ):
47
- f .write (json_to_bytes ({'observation' : obs_idx }))
48
- f .write (nl )
49
-
50
-
51
- def write_nl (f : BinaryIO ):
52
- f .write (nl )
53
-
54
-
55
- def write_outcome_marker (f : BinaryIO , obs_idx : int ):
56
- f .write (json_to_bytes ({'outcome' : obs_idx }))
57
- f .write (nl )
58
-
59
-
60
- def create_example (fname : str , nr_contexts = 1 ):
34
class LogTestExampleBuilder:
  """Writes the pieces of a test log to an already opened binary file.

  The builder can deliberately corrupt the log at one chosen position (see
  `ErrorMarkers`): either by replacing the newline that ends a section with
  `error_newline`, or, for `TENSOR_BUF_POS`, by dropping the first half of
  every tensor buffer it is asked to write.
  """

  # Separator written after a section when no error is injected there.
  newline = b'\n'
  # Written instead of `newline` at the position selected for corruption.
  error_newline = b'hi there'

  class ErrorMarkers(enum.IntEnum):
    """Positions in the log at which an error can be injected."""
    NONE = 0
    AFTER_HEADER = enum.auto()
    CTX_MARKER_POS = enum.auto()
    OBS_MARKER_POS = enum.auto()
    OUTCOME_MARKER_POS = enum.auto()
    TENSOR_BUF_POS = enum.auto()
    TENSORS_POS = enum.auto()
    OUTCOME_POS = enum.auto()

  def __init__(
      self,
      *,
      opened_file: BinaryIO,
      introduce_error_pos: ErrorMarkers = ErrorMarkers.NONE,
  ):
    """Initializes the builder.

    Args:
      opened_file: binary file object, opened for writing, that receives the
        log.
      introduce_error_pos: the position at which to corrupt the log;
        `ErrorMarkers.NONE` produces a well-formed log.
    """
    self._opened_file = opened_file
    self._introduce_error_pos = introduce_error_pos

  def write_buff(self, buffer: list, ct):
    """Writes `buffer` as a ctypes array with element type `ct`.

    When the builder was configured with `TENSOR_BUF_POS`, only the second
    half of `buffer` is written, so the data is shorter than expected.

    Args:
      buffer: the values to write.
      ct: the ctypes element type, e.g. `ctypes.c_float`.
    """
    if self._introduce_error_pos == self.ErrorMarkers.TENSOR_BUF_POS:
      buffer = buffer[len(buffer) // 2:]
    # we should get the ctypes array to bytes for pytype to be happy.
    # pytype:disable=wrong-arg-types
    self._opened_file.write((ct * len(buffer))(*buffer))
    # pytype:enable=wrong-arg-types

  def write_newline(self, position=None):
    """Writes a section terminator.

    Args:
      position: the `ErrorMarkers` value identifying where in the log this
        terminator sits. If it matches the configured error position,
        `error_newline` is written instead of `newline`. `None` and
        `ErrorMarkers.NONE` never trigger corruption.
    """
    # Guard against NONE == NONE: a builder that injects no error must never
    # emit the error terminator, whatever position label the caller passes.
    corrupt = (
        position is not None and position != self.ErrorMarkers.NONE and
        position == self._introduce_error_pos)
    self._opened_file.write(self.error_newline if corrupt else self.newline)

  def write_context_marker(self, name: str):
    """Writes the JSON `{'context': name}` line."""
    self._opened_file.write(json_to_bytes({'context': name}))
    self.write_newline(self.ErrorMarkers.CTX_MARKER_POS)

  def write_observation_marker(self, obs_idx: int):
    """Writes the JSON `{'observation': obs_idx}` line."""
    self._opened_file.write(json_to_bytes({'observation': obs_idx}))
    self.write_newline(self.ErrorMarkers.OBS_MARKER_POS)

  def write_outcome_marker(self, obs_idx: int):
    """Writes the JSON `{'outcome': obs_idx}` line."""
    self._opened_file.write(json_to_bytes({'outcome': obs_idx}))
    self.write_newline(self.ErrorMarkers.OUTCOME_MARKER_POS)

  def write_header(self, json_header: dict):
    """Writes the JSON header; the caller writes the terminating newline."""
    self._opened_file.write(json_to_bytes(json_header))
85
+
86
+
87
def create_example(fname: str,
                   *,
                   nr_contexts=1,
                   introduce_errors_pos: LogTestExampleBuilder
                   .ErrorMarkers = LogTestExampleBuilder.ErrorMarkers.NONE):
  """Writes a sample log to `fname`.

  The log has a header declaring two features ('tensor_name2', a 2x3 float
  tensor, and 'tensor_name1', a 3x1 int64_t tensor) and a float 'reward'
  score, followed by `nr_contexts` contexts with two observations each.

  Args:
    fname: path of the file to (over)write.
    nr_contexts: number of contexts to emit.
    introduce_errors_pos: position at which the log is deliberately
      corrupted; `NONE` yields a well-formed log.
  """
  t0_val = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
  t1_val = [1, 2, 3]
  s = [1.2]

  with open(fname, 'wb') as f:
    example_writer = LogTestExampleBuilder(
        opened_file=f, introduce_error_pos=introduce_errors_pos)

    def write_example_obs(obs: int, t0: list, t1: list):
      # One observation: marker, both feature tensors, then the outcome
      # marker and the current score.
      example_writer.write_observation_marker(obs)
      example_writer.write_buff(t0, ctypes.c_float)
      example_writer.write_buff(t1, ctypes.c_int64)
      example_writer.write_newline(
          LogTestExampleBuilder.ErrorMarkers.TENSORS_POS)
      example_writer.write_outcome_marker(obs)
      example_writer.write_buff(s, ctypes.c_float)
      example_writer.write_newline(
          LogTestExampleBuilder.ErrorMarkers.OUTCOME_POS)

    example_writer.write_header({
        'features': [{
            'name': 'tensor_name2',
            'port': 0,
            'shape': [2, 3],
            'type': 'float',
        }, {
            'name': 'tensor_name1',
            'port': 0,
            'shape': [3, 1],
            'type': 'int64_t',
        }],
        'score': {
            'name': 'reward',
            'port': 0,
            'shape': [1],
            'type': 'float'
        }
    })
    example_writer.write_newline(
        LogTestExampleBuilder.ErrorMarkers.AFTER_HEADER)

    for ctx_id in range(nr_contexts):
      # The values carry over from the previous context (including the +1
      # applied between observations) and are then offset by ctx_id * 10.
      t0_val = [v + ctx_id * 10 for v in t0_val]
      t1_val = [v + ctx_id * 10 for v in t1_val]
      example_writer.write_context_marker(f'context_nr_{ctx_id}')

      write_example_obs(0, t0_val, t1_val)
      t0_val = [v + 1 for v in t0_val]
      t1_val = [v + 1 for v in t1_val]
      s[0] += 1
      write_example_obs(1, t0_val, t1_val)
110
141
111
142
112
143
class LogReaderTest (tf .test .TestCase ):
@@ -246,6 +277,41 @@ def test_seq_example_conversion(self):
246
277
""" , tf .train .SequenceExample ())
247
278
self .assertProtoEquals (expected_ctx_0 , seq_examples ['context_nr_0' ])
248
279
280
def test_errors(self):
  """Checks that reading a log corrupted at any marker position raises."""
  logfile = self.create_tempfile()
  for error_marker in LogTestExampleBuilder.ErrorMarkers:
    # NONE produces a well-formed log, so there is nothing to assert for it.
    if error_marker == LogTestExampleBuilder.ErrorMarkers.NONE:
      continue
    # subTest reports which marker failed and keeps checking the others.
    with self.subTest(error_marker=error_marker.name):
      create_example(logfile, introduce_errors_pos=error_marker)
      with self.assertRaises(Exception):
        log_reader.read_log_as_sequence_examples(logfile)
288
+
289
def test_truncated_tensors(self):
  """Checks that a log ending in the middle of a tensor buffer raises.

  The header declares a 2x3 float feature, but after the first observation
  marker only a single `c_int16` value is written and the file ends.
  """
  logfile = self.create_tempfile()
  with open(logfile, 'wb') as f:
    writer = LogTestExampleBuilder(opened_file=f)
    writer.write_header({
        'features': [{
            'name': 'tensor_name',
            'port': 0,
            'shape': [2, 3],
            'type': 'float',
        }],
        'score': {
            'name': 'reward',
            'port': 0,
            'shape': [1],
            'type': 'float'
        }
    })
    writer.write_newline()
    writer.write_context_marker('whatever')
    writer.write_observation_marker(0)
    # Deliberately far shorter than the tensor the header declares.
    writer.write_buff([1], ctypes.c_int16)

  with self.assertRaises(Exception):
    log_reader.read_log_as_sequence_examples(logfile)
314
+
249
315
250
316
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
0 commit comments