Skip to content

Commit 9d00bcf

Browse files
authored
[NFC] Log reader changes for MLGO environments. (#242)
* Log reader changes for MLGO environments.
* Addressing comments.
* Fix a few more spurious changes.
* Tuple -> Union.
* Replace raw_bytes with to_numpy.
* Type hint for to_numpy.
* NDArray -> ndarray.
1 parent c963a8a commit 9d00bcf

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

compiler_opt/rl/log_reader.py

Lines changed: 20 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,8 @@
6161
import json
6262
import math
6363

64-
from typing import Any, BinaryIO, Dict, Generator, List, Optional
64+
from typing import Any, BinaryIO, Dict, Generator, List, Optional, Union
65+
import numpy as np
6566
import tensorflow as tf
6667

6768
_element_type_name_map = {
@@ -86,6 +87,11 @@
8687
}
8788

8889

90+
def convert_dtype_to_ctype(dtype: str) -> Union[type, tf.dtypes.DType]:
91+
"""Public interface for the _dtype_to_ctype dict."""
92+
return _dtype_to_ctype[dtype]
93+
94+
8995
def create_tensorspec(d: Dict[str, Any]) -> tf.TensorSpec:
9096
name: str = d['name']
9197
shape: List[int] = [int(e) for e in d['shape']]
@@ -120,6 +126,12 @@ def __init__(self, spec: tf.TensorSpec, buffer: bytes):
120126
def spec(self):
121127
return self._spec
122128

129+
def to_numpy(self) -> np.ndarray:
130+
return np.frombuffer(
131+
self._buffer,
132+
dtype=convert_dtype_to_ctype(self._spec.dtype),
133+
count=self._len)
134+
123135
def _set_view(self):
124136
# c_char_p is a nul-terminated string, so the more appropriate cast here
125137
# would be POINTER(c_char), but unfortunately, c_char_p is the only
@@ -205,11 +217,15 @@ def _enumerate_log_from_stream(
205217
score=score)
206218

207219

220+
def read_log_from_file(f) -> Generator[ObservationRecord, None, None]:
221+
header = _read_header(f)
222+
if header:
223+
yield from _enumerate_log_from_stream(f, header)
224+
225+
208226
def read_log(fname: str) -> Generator[ObservationRecord, None, None]:
209227
with open(fname, 'rb') as f:
210-
header = _read_header(f)
211-
if header:
212-
yield from _enumerate_log_from_stream(f, header)
228+
yield from read_log_from_file(f)
213229

214230

215231
def _add_feature(se: tf.train.SequenceExample, spec: tf.TensorSpec,

compiler_opt/rl/log_reader_test.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from google.protobuf import text_format # pytype: disable=pyi-error
2323
from typing import BinaryIO
2424

25+
import numpy as np
2526
import tensorflow as tf
2627

2728

@@ -38,26 +39,25 @@ def write_buff(f: BinaryIO, buffer: list, ct):
3839

3940

4041
def write_context_marker(f: BinaryIO, name: str):
41-
f.write(nl)
4242
f.write(json_to_bytes({'context': name}))
43+
f.write(nl)
4344

4445

4546
def write_observation_marker(f: BinaryIO, obs_idx: int):
46-
f.write(nl)
4747
f.write(json_to_bytes({'observation': obs_idx}))
48+
f.write(nl)
4849

4950

50-
def begin_features(f: BinaryIO):
51+
def write_nl(f: BinaryIO):
5152
f.write(nl)
5253

5354

5455
def write_outcome_marker(f: BinaryIO, obs_idx: int):
55-
f.write(nl)
5656
f.write(json_to_bytes({'outcome': obs_idx}))
57+
f.write(nl)
5758

5859

5960
def create_example(fname: str, nr_contexts=1):
60-
6161
t0_val = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
6262
t1_val = [1, 2, 3]
6363
s = [1.2]
@@ -69,12 +69,12 @@ def create_example(fname: str, nr_contexts=1):
6969
'name': 'tensor_name2',
7070
'port': 0,
7171
'shape': [2, 3],
72-
'type': 'float'
72+
'type': 'float',
7373
}, {
7474
'name': 'tensor_name1',
7575
'port': 0,
7676
'shape': [3, 1],
77-
'type': 'int64_t'
77+
'type': 'int64_t',
7878
}],
7979
'score': {
8080
'name': 'reward',
@@ -83,29 +83,30 @@ def create_example(fname: str, nr_contexts=1):
8383
'type': 'float'
8484
}
8585
}))
86+
write_nl(f)
8687
for ctx_id in range(nr_contexts):
8788
t0_val = [v + ctx_id * 10 for v in t0_val]
8889
t1_val = [v + ctx_id * 10 for v in t1_val]
8990
write_context_marker(f, f'context_nr_{ctx_id}')
9091
write_observation_marker(f, 0)
91-
begin_features(f)
9292
write_buff(f, t0_val, ctypes.c_float)
9393
write_buff(f, t1_val, ctypes.c_int64)
94+
write_nl(f)
9495
write_outcome_marker(f, 0)
95-
begin_features(f)
9696
write_buff(f, s, ctypes.c_float)
97+
write_nl(f)
9798

9899
t0_val = [v + 1 for v in t0_val]
99100
t1_val = [v + 1 for v in t1_val]
100101
s[0] += 1
101102

102103
write_observation_marker(f, 1)
103-
begin_features(f)
104104
write_buff(f, t0_val, ctypes.c_float)
105105
write_buff(f, t1_val, ctypes.c_int64)
106+
write_nl(f)
106107
write_outcome_marker(f, 1)
107-
begin_features(f)
108108
write_buff(f, s, ctypes.c_float)
109+
write_nl(f)
109110

110111

111112
class LogReaderTest(tf.test.TestCase):
@@ -155,6 +156,19 @@ def test_read_log(self):
155156
obs_id += 1
156157
self.assertEqual(obs_id, 2)
157158

159+
def test_to_numpy(self):
160+
logfile = self.create_tempfile()
161+
create_example(logfile)
162+
t0_val = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
163+
t1_val = [1, 2, 3]
164+
for record in log_reader.read_log(logfile):
165+
np.testing.assert_allclose(record.feature_values[0].to_numpy(),
166+
np.array(t0_val))
167+
np.testing.assert_allclose(record.feature_values[1].to_numpy(),
168+
np.array(t1_val))
169+
t0_val = [v + 1 for v in t0_val]
170+
t1_val = [v + 1 for v in t1_val]
171+
158172
def test_seq_example_conversion(self):
159173
logfile = self.create_tempfile()
160174
create_example(logfile, nr_contexts=2)

0 commit comments

Comments
 (0)