Skip to content

Commit ea7d2d6

Browse files
authored
script to merge multiple best trajectories into one, and dump in json/csv format (#162)
* script to merge multiple best trajectories into one * add an issue and comment it
1 parent 0dbc785 commit ea7d2d6

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
r"""Merge multiple best trajectory repo into one."""
16+
17+
import json
18+
19+
from absl import app
20+
from absl import flags
21+
from absl import logging
22+
23+
from compiler_opt.rl import best_trajectory
24+
25+
_BEST_TRAJECTORY_PATHS = flags.DEFINE_multi_string(
26+
'best_trajectory_paths', '',
27+
'best trajectory repo dump to be merged in json format.')
28+
_OUTPUT_JSON_PATH = flags.DEFINE_string(
29+
'output_json_path', '',
30+
'output path of the merged best trajectory repo in json format if given.')
31+
_OUTPUT_CSV_PATH = flags.DEFINE_string(
32+
'output_csv_path', '',
33+
'output path of the merged best trajectory repo in csv format if given.')
34+
35+
FLAGS = flags.FLAGS
36+
37+
38+
def main(argv):
39+
if len(argv) > 1:
40+
raise app.UsageError('Too many command-line arguments.')
41+
42+
# we don't use action_name in the merging process here, so just set it to
43+
# empty string.
44+
merged_best_trajectory_repo = best_trajectory.BestTrajectoryRepo(
45+
action_name='')
46+
47+
for path in _BEST_TRAJECTORY_PATHS.value:
48+
logging.info('merging repo: %s', path)
49+
tmp = best_trajectory.BestTrajectoryRepo(action_name='')
50+
# The json file is broken sometimes.
51+
# Open issue: https://github.com/google/ml-compiler-opt/issues/163
52+
try:
53+
tmp.load_from_json_file(path)
54+
merged_best_trajectory_repo.combine_with_other_repo(tmp)
55+
except json.decoder.JSONDecodeError:
56+
logging.error('failed to load input repo: %s', path)
57+
58+
if _OUTPUT_JSON_PATH.value:
59+
tmp.sink_to_json_file(_OUTPUT_JSON_PATH.value)
60+
61+
if _OUTPUT_CSV_PATH.value:
62+
tmp.sink_to_csv_file(_OUTPUT_CSV_PATH.value)
63+
64+
65+
if __name__ == '__main__':
66+
app.run(main)

0 commit comments

Comments
 (0)