Skip to content

Commit 6294cf9

Browse files
Refactor combine_training_corpus.py into script and library (#268)
This patch refactors combine_training_corpus.py into a library file that can easily be imported in downstream projects (and other utilities). This also makes unit testing slightly easier (as no special accommodations have to be made for CLI flags). This patch also adds two unit tests for combining training corpora.
1 parent dafc347 commit 6294cf9

File tree

3 files changed

+130
-34
lines changed

3 files changed

+130
-34
lines changed

compiler_opt/tools/combine_training_corpus.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -34,54 +34,22 @@
3434
and corpus2 are combined into combinedcorpus.
3535
"""
3636

37-
import json
38-
import os
39-
4037
from absl import app
4138
from absl import flags
42-
from absl import logging
4339

44-
import tensorflow as tf
40+
from compiler_opt.tools import combine_training_corpus_lib
4541

4642
flags.DEFINE_string('root_dir', '', 'root dir of module paths to combine.')
4743

4844
FLAGS = flags.FLAGS
4945

50-
_FILE_NAME = 'corpus_description.json'
51-
5246

5347
def main(argv):
5448
if len(argv) > 1:
5549
raise app.UsageError('Too many command-line arguments.')
5650

57-
module_names = []
58-
output_corpus_description = {}
59-
60-
for sub_dir in tf.io.gfile.listdir(FLAGS.root_dir):
61-
path = os.path.join(FLAGS.root_dir, sub_dir, _FILE_NAME)
62-
63-
logging.info('processing %s', path)
64-
65-
if not tf.io.gfile.exists(path):
66-
logging.error('%s does not exist.', path)
67-
continue
68-
69-
with tf.io.gfile.GFile(path, 'r') as f:
70-
corpus_description = json.load(f)
71-
module_names.extend([
72-
os.path.join(sub_dir, name) for name in corpus_description['modules']
73-
])
74-
del corpus_description['modules']
75-
if len(output_corpus_description) == 0:
76-
output_corpus_description = corpus_description
77-
elif corpus_description != output_corpus_description:
78-
raise ValueError('Input corpora differ more than modules.')
79-
80-
output_corpus_description['modules'] = module_names
81-
82-
with tf.io.gfile.GFile(os.path.join(FLAGS.root_dir, _FILE_NAME), 'w') as f:
83-
json.dump(output_corpus_description, f, indent=2)
8451

52+
combine_training_corpus_lib.combine_corpus(FLAGS.root_dir)
8553

8654
if __name__ == '__main__':
8755
app.run(main)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Library for combining training corpora."""
16+
17+
import os
18+
import json
19+
20+
from absl import logging
21+
22+
import tensorflow as tf
23+
24+
_FILE_NAME = 'corpus_description.json'
25+
26+
27+
def combine_corpus(root_dir: str) -> None:
  """Combines the sub-corpora found under root_dir into a single corpus.

  Every entry of `root_dir` that contains a `corpus_description.json` is
  treated as a sub-corpus. The module names of all sub-corpora are prefixed
  with the name of their subdirectory and concatenated, and the result is
  written to `corpus_description.json` directly inside `root_dir`. Entries
  without a description file are logged as errors and skipped.

  Args:
    root_dir: Directory whose immediate subdirectories are the corpora to
      combine. The combined description is written into this directory.

  Raises:
    ValueError: If two sub-corpus descriptions differ in anything other than
      their 'modules' entry.
    KeyError: If a sub-corpus description has no 'modules' entry.
  """
  module_names = []
  # None (rather than an empty dict) marks "no sub-corpus seen yet". A first
  # description that has no keys besides 'modules' is then still compared
  # against the ones that follow instead of being silently replaced.
  shared_description = None

  # tf.io.gfile.listdir returns entries in arbitrary order; sorting keeps the
  # order of the combined module list reproducible between runs.
  for sub_dir in sorted(tf.io.gfile.listdir(root_dir)):
    path = os.path.join(root_dir, sub_dir, _FILE_NAME)

    logging.info('processing %s', path)

    if not tf.io.gfile.exists(path):
      logging.error('%s does not exist.', path)
      continue

    with tf.io.gfile.GFile(path, 'r') as f:
      corpus_description = json.load(f)

    # Module paths are stored relative to their own corpus, so they have to be
    # re-rooted at the sub-corpus directory. Popping 'modules' leaves only the
    # properties that must be identical across all sub-corpora.
    module_names.extend(
        os.path.join(sub_dir, name)
        for name in corpus_description.pop('modules'))

    if shared_description is None:
      shared_description = corpus_description
    elif corpus_description != shared_description:
      raise ValueError('Input corpora differ by more than modules.')

  # With no usable sub-corpus the output only holds an empty module list.
  output_corpus_description = dict(shared_description or {})
  output_corpus_description['modules'] = module_names

  with tf.io.gfile.GFile(os.path.join(root_dir, _FILE_NAME), 'w') as f:
    json.dump(output_corpus_description, f, indent=2)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Tests for combining training corpora."""
16+
17+
import json
18+
import os
19+
20+
from absl.testing import absltest
21+
22+
from compiler_opt.tools import combine_training_corpus_lib
23+
24+
25+
class CombineTrainingCorpusTest(absltest.TestCase):
  """Tests combine_corpus against sub-corpora created in a temp directory."""

  def _write_subcorpora(self, descriptions):
    """Creates a temp corpus directory holding one sub-corpus per entry.

    Args:
      descriptions: Maps a subdirectory name to the dict that is serialized
        as that subdirectory's corpus_description.json.

    Returns:
      The temporary directory object containing the sub-corpora.
    """
    corpus_dir = self.create_tempdir()
    for name, description in descriptions.items():
      sub_dir = corpus_dir.mkdir(dir_path=name)
      description_file = sub_dir.create_file(
          file_path='corpus_description.json')
      description_file.write_text(json.dumps(description))
    return corpus_dir

  def test_combine_corpus(self):
    """Compatible sub-corpora are merged with directory-prefixed modules."""
    corpus_dir = self._write_subcorpora({
        'subcorpus1': {
            'has_thinlto': False,
            'modules': ['test1.o', 'test2.o']
        },
        'subcorpus2': {
            'has_thinlto': False,
            'modules': ['test3.o', 'test4.o']
        },
    })

    combine_training_corpus_lib.combine_corpus(corpus_dir.full_path)

    combined_path = os.path.join(corpus_dir, 'corpus_description.json')
    with open(combined_path, encoding='utf-8') as combined_file:
      combined = json.load(combined_file)

    self.assertEqual(combined['has_thinlto'], False)
    self.assertLen(combined['modules'], 4)
    expected_modules = (
        'subcorpus1/test1.o',
        'subcorpus1/test2.o',
        'subcorpus2/test3.o',
        'subcorpus2/test4.o',
    )
    for module in expected_modules:
      self.assertIn(module, combined['modules'])

  def test_different_corpora(self):
    """Sub-corpora that disagree on a non-module property are rejected."""
    corpus_dir = self._write_subcorpora({
        'subcorpus1': {
            'has_thinlto': False,
            'modules': ['test1.o']
        },
        'subcorpus2': {
            'has_thinlto': True,
            'modules': ['test2.o']
        },
    })

    with self.assertRaises(ValueError):
      combine_training_corpus_lib.combine_corpus(corpus_dir.full_path)
71+
72+
73+
if __name__ == '__main__':
74+
absltest.main()

0 commit comments

Comments
 (0)