22
22
from google .protobuf import text_format # pytype: disable=pyi-error
23
23
from typing import BinaryIO
24
24
25
+ import numpy as np
25
26
import tensorflow as tf
26
27
27
28
@@ -38,26 +39,25 @@ def write_buff(f: BinaryIO, buffer: list, ct):
38
39
39
40
40
41
def write_context_marker (f : BinaryIO , name : str ):
41
- f .write (nl )
42
42
f .write (json_to_bytes ({'context' : name }))
43
+ f .write (nl )
43
44
44
45
45
46
def write_observation_marker (f : BinaryIO , obs_idx : int ):
46
- f .write (nl )
47
47
f .write (json_to_bytes ({'observation' : obs_idx }))
48
+ f .write (nl )
48
49
49
50
50
- def begin_features (f : BinaryIO ):
51
+ def write_nl (f : BinaryIO ):
51
52
f .write (nl )
52
53
53
54
54
55
def write_outcome_marker (f : BinaryIO , obs_idx : int ):
55
- f .write (nl )
56
56
f .write (json_to_bytes ({'outcome' : obs_idx }))
57
+ f .write (nl )
57
58
58
59
59
60
def create_example (fname : str , nr_contexts = 1 ):
60
-
61
61
t0_val = [0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 ]
62
62
t1_val = [1 , 2 , 3 ]
63
63
s = [1.2 ]
@@ -69,12 +69,12 @@ def create_example(fname: str, nr_contexts=1):
69
69
'name' : 'tensor_name2' ,
70
70
'port' : 0 ,
71
71
'shape' : [2 , 3 ],
72
- 'type' : 'float'
72
+ 'type' : 'float' ,
73
73
}, {
74
74
'name' : 'tensor_name1' ,
75
75
'port' : 0 ,
76
76
'shape' : [3 , 1 ],
77
- 'type' : 'int64_t'
77
+ 'type' : 'int64_t' ,
78
78
}],
79
79
'score' : {
80
80
'name' : 'reward' ,
@@ -83,29 +83,30 @@ def create_example(fname: str, nr_contexts=1):
83
83
'type' : 'float'
84
84
}
85
85
}))
86
+ write_nl (f )
86
87
for ctx_id in range (nr_contexts ):
87
88
t0_val = [v + ctx_id * 10 for v in t0_val ]
88
89
t1_val = [v + ctx_id * 10 for v in t1_val ]
89
90
write_context_marker (f , f'context_nr_{ ctx_id } ' )
90
91
write_observation_marker (f , 0 )
91
- begin_features (f )
92
92
write_buff (f , t0_val , ctypes .c_float )
93
93
write_buff (f , t1_val , ctypes .c_int64 )
94
+ write_nl (f )
94
95
write_outcome_marker (f , 0 )
95
- begin_features (f )
96
96
write_buff (f , s , ctypes .c_float )
97
+ write_nl (f )
97
98
98
99
t0_val = [v + 1 for v in t0_val ]
99
100
t1_val = [v + 1 for v in t1_val ]
100
101
s [0 ] += 1
101
102
102
103
write_observation_marker (f , 1 )
103
- begin_features (f )
104
104
write_buff (f , t0_val , ctypes .c_float )
105
105
write_buff (f , t1_val , ctypes .c_int64 )
106
+ write_nl (f )
106
107
write_outcome_marker (f , 1 )
107
- begin_features (f )
108
108
write_buff (f , s , ctypes .c_float )
109
+ write_nl (f )
109
110
110
111
111
112
class LogReaderTest (tf .test .TestCase ):
@@ -155,6 +156,19 @@ def test_read_log(self):
155
156
obs_id += 1
156
157
self .assertEqual (obs_id , 2 )
157
158
159
+ def test_to_numpy (self ):
160
+ logfile = self .create_tempfile ()
161
+ create_example (logfile )
162
+ t0_val = [0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 ]
163
+ t1_val = [1 , 2 , 3 ]
164
+ for record in log_reader .read_log (logfile ):
165
+ np .testing .assert_allclose (record .feature_values [0 ].to_numpy (),
166
+ np .array (t0_val ))
167
+ np .testing .assert_allclose (record .feature_values [1 ].to_numpy (),
168
+ np .array (t1_val ))
169
+ t0_val = [v + 1 for v in t0_val ]
170
+ t1_val = [v + 1 for v in t1_val ]
171
+
158
172
def test_seq_example_conversion (self ):
159
173
logfile = self .create_tempfile ()
160
174
create_example (logfile , nr_contexts = 2 )
0 commit comments