Skip to content

Commit 3ba3ebf

Browse files
authored
utility class for processing and recording best trajectories. (#152)
* best trajetcories * pytype * pytype * pylint
1 parent 9b1232f commit 3ba3ebf

File tree

2 files changed

+235
-0
lines changed

2 files changed

+235
-0
lines changed

compiler_opt/rl/best_trajectory.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Module for storing and processing best trajectories."""
16+
17+
import dataclasses
18+
import json
19+
from typing import Dict, List
20+
21+
import tensorflow as tf
22+
23+
from compiler_opt.rl import constant
24+
25+
26+
@dataclasses.dataclass(frozen=True)
27+
class BestTrajectory:
28+
reward: float
29+
action_list: List[int]
30+
31+
32+
class BestTrajectoryRepo:
33+
"""Class for storing and processing best trajectory related operations."""
34+
35+
def __init__(self, action_name: str):
36+
"""Constructor.
37+
38+
Args:
39+
action_name: action name of the trajectory, used for extracting action
40+
list from tensorflow.SequenceExample.
41+
"""
42+
# {module_name: {identifier: best trajectory}}
43+
self._best_trajectories: Dict[str, Dict[str, BestTrajectory]] = {}
44+
self._action_name: str = action_name
45+
46+
@property
47+
def best_trajectories(self) -> Dict[str, Dict[str, BestTrajectory]]:
48+
return self._best_trajectories.copy()
49+
50+
def sink_to_json_file(self, path: str):
51+
with tf.io.gfile.GFile(path, 'w') as f:
52+
json.dump(self._best_trajectories, f, cls=constant.DataClassJSONEncoder)
53+
54+
def load_from_json_file(self, path: str):
55+
with tf.io.gfile.GFile(path, 'r') as f:
56+
data = json.load(f)
57+
for k, v in data.items():
58+
if v:
59+
self._best_trajectories[k] = {
60+
sub_k: BestTrajectory(**sub_v) for sub_k, sub_v in v.items()
61+
}
62+
63+
def sink_to_csv_file(self, path: str):
64+
"""sink to csv file format consumable by compiler."""
65+
with tf.io.gfile.GFile(path, 'w') as f:
66+
for k, v in self._best_trajectories.items():
67+
for sub_k, sub_v in v.items():
68+
f.write(','.join([k, sub_k] + [str(x) for x in sub_v.action_list]) +
69+
'\n')
70+
71+
def combine_with_other_repo(self, other: 'BestTrajectoryRepo'):
72+
"""combine and update with other best trajectory repo."""
73+
for k, v in other.best_trajectories.items():
74+
if k not in self._best_trajectories:
75+
self._best_trajectories[k] = v
76+
continue
77+
for sub_k, sub_v in v.items():
78+
if sub_v.reward < self._best_trajectories[k][sub_k].reward:
79+
self._best_trajectories[k][sub_k] = sub_v
80+
81+
def update_if_better_trajectory(self, module_name: str, identifier: str,
82+
reward: float, trajectory: bytes):
83+
"""update with incoming trajectory if the reward is lower.
84+
85+
Args:
86+
module_name: module name of the trajectory.
87+
identifier: identifier of the trajectory within module.
88+
reward: reward of the trajectory.
89+
trajectory: trajectory in the format of serialized SequenceExample.
90+
"""
91+
if module_name not in self._best_trajectories:
92+
self._best_trajectories[module_name] = {}
93+
if (identifier not in self._best_trajectories[module_name] or
94+
self._best_trajectories[module_name][identifier].reward > reward):
95+
example = tf.train.SequenceExample.FromString(trajectory)
96+
action_list = [
97+
x.int64_list.value[0]
98+
for x in example.feature_lists.feature_list[self._action_name].feature
99+
]
100+
self._best_trajectories[module_name][identifier] = BestTrajectory(
101+
reward=reward, action_list=action_list)
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Tests for compiler_opt.rl.best_trajectory."""
16+
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
19+
20+
import tensorflow as tf
21+
22+
from compiler_opt.rl import best_trajectory
23+
24+
_ACTION_NAME = 'mock'
25+
26+
27+
def _get_test_repo_1():
28+
repo = best_trajectory.BestTrajectoryRepo(action_name=_ACTION_NAME)
29+
# pylint: disable=protected-access
30+
repo._best_trajectories['module_1'] = {
31+
'function_1':
32+
best_trajectory.BestTrajectory(reward=3.4, action_list=[1, 3, 5]),
33+
'function_2':
34+
best_trajectory.BestTrajectory(reward=1.2, action_list=[9, 7, 5])
35+
}
36+
# pylint: enable=protected-access
37+
return repo
38+
39+
40+
def _get_test_repo_2():
41+
repo = best_trajectory.BestTrajectoryRepo(action_name=_ACTION_NAME)
42+
# pylint: disable=protected-access
43+
repo._best_trajectories['module_1'] = {
44+
'function_1':
45+
best_trajectory.BestTrajectory(reward=2.3, action_list=[1, 3]),
46+
'function_2':
47+
best_trajectory.BestTrajectory(reward=3.4, action_list=[9, 7])
48+
}
49+
repo._best_trajectories['module_2'] = {
50+
'function_1':
51+
best_trajectory.BestTrajectory(reward=7.8, action_list=[2, 4, 6]),
52+
}
53+
# pylint: enable=protected-access
54+
return repo
55+
56+
57+
def _get_combined_repo():
58+
repo = best_trajectory.BestTrajectoryRepo(action_name=_ACTION_NAME)
59+
# pylint: disable=protected-access
60+
repo._best_trajectories['module_1'] = {
61+
'function_1':
62+
best_trajectory.BestTrajectory(reward=2.3, action_list=[1, 3]),
63+
'function_2':
64+
best_trajectory.BestTrajectory(reward=1.2, action_list=[9, 7, 5])
65+
}
66+
repo._best_trajectories['module_2'] = {
67+
'function_1':
68+
best_trajectory.BestTrajectory(reward=7.8, action_list=[2, 4, 6]),
69+
}
70+
# pylint: enable=protected-access
71+
return repo
72+
73+
74+
def _create_sequence_example(action_list):
75+
example = tf.train.SequenceExample()
76+
for action in action_list:
77+
example.feature_lists.feature_list[_ACTION_NAME].feature.add(
78+
).int64_list.value.append(action)
79+
return example.SerializeToString()
80+
81+
82+
class BestTrajectoryTest(parameterized.TestCase):
83+
84+
@parameterized.named_parameters(('repo_1', _get_test_repo_1()),
85+
('repo_2', _get_test_repo_2()))
86+
def test_sink_load_json_file(self, repo):
87+
path = self.create_tempfile().full_path
88+
repo.sink_to_json_file(path)
89+
loaded_repo = best_trajectory.BestTrajectoryRepo(action_name=_ACTION_NAME)
90+
loaded_repo.load_from_json_file(path)
91+
self.assertDictEqual(repo.best_trajectories, loaded_repo.best_trajectories)
92+
93+
def test_sink_to_csv_file(self):
94+
path = self.create_tempfile().full_path
95+
repo = _get_test_repo_1()
96+
repo.sink_to_csv_file(path)
97+
with open(path, 'r', encoding='utf-8') as f:
98+
text = f.read()
99+
100+
self.assertEqual(text,
101+
'module_1,function_1,1,3,5\nmodule_1,function_2,9,7,5\n')
102+
103+
@parameterized.named_parameters(
104+
{
105+
'testcase_name': 'repo_1_combine_2',
106+
'base_repo': _get_test_repo_1(),
107+
'second_repo': _get_test_repo_2()
108+
}, {
109+
'testcase_name': 'repo_2_combine_1',
110+
'base_repo': _get_test_repo_2(),
111+
'second_repo': _get_test_repo_1()
112+
})
113+
def test_combine_with_other_repo(self, base_repo, second_repo):
114+
base_repo.combine_with_other_repo(second_repo)
115+
self.assertDictEqual(base_repo.best_trajectories,
116+
_get_combined_repo().best_trajectories)
117+
118+
def test_update_if_better_trajectory(self):
119+
repo = _get_test_repo_1()
120+
repo.update_if_better_trajectory(
121+
'module_1', 'function_1', 2.3,
122+
_create_sequence_example(action_list=[1, 3]))
123+
repo.update_if_better_trajectory(
124+
'module_1', 'function_2', 3.4,
125+
_create_sequence_example(action_list=[9, 7]))
126+
repo.update_if_better_trajectory(
127+
'module_2', 'function_1', 7.8,
128+
_create_sequence_example(action_list=[2, 4, 6]))
129+
self.assertDictEqual(repo.best_trajectories,
130+
_get_combined_repo().best_trajectories)
131+
132+
133+
if __name__ == '__main__':
134+
absltest.main()

0 commit comments

Comments
 (0)