Skip to content

Commit 220c0f5

Browse files
authored
Add Corpus abstraction & sorting w/ size (#77)
* Add Corpus abstraction & sorting w/ size — rather than pass around a list of module_specs, pass around a Corpus object instead; built-in filter and sample methods; sample will sort by size descending, with the goal of optimizing compile order
* Change .size -> __len__
* Default to unbiased sampling; add sampler option
* Add separate constructor for testing
* Make sampler not select repeats
* Replace 'corp' with 'cps'
* Switch to optimized algorithm
* Resolve comments
1 parent 7e4b19f commit 220c0f5

File tree

6 files changed

+227
-33
lines changed

6 files changed

+227
-33
lines changed

compiler_opt/rl/corpus.py

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
"""ModuleSpec definition and utility command line parsing functions."""
16+
import math
17+
import random
18+
import re
1619

1720
from absl import logging
1821
from dataclasses import dataclass
@@ -30,9 +33,100 @@ class ModuleSpec:
3033
"""
3134
name: str
3235
exec_cmd: Tuple[str, ...] = ()
33-
34-
35-
def build_modulespecs_from_datapath(
36+
size: int = 0
37+
38+
39+
class SamplerBucketRoundRobin:
  """Stateful sampler that picks one module_spec from each of `k` buckets.

  The input list is split into `k` sequential sections ("buckets") of roughly
  equal length, and one element is drawn uniformly at random from each. The
  buckets are visited in an order that approximates round-robin over `n`
  groups. Bucket ranges are cached per (length, k, n) triple, so reusing one
  instance across calls amortizes the setup cost.
  """

  def __init__(self):
    # Cache: (len(module_specs), k, n) -> tuple of (start, end) index pairs,
    # one per bucket, listed in visiting order.
    self._ranges = {}

  def __call__(self,
               module_specs: List[ModuleSpec],
               k: int,
               n: int = 20) -> List[ModuleSpec]:
    """
    Args:
      module_specs: list of module_specs to sample from
      k: number of modules to sample
      n: number of buckets to use
    """
    # Credits to yundi@ for the highly optimized algo: split module_specs into
    # k buckets, then choose the order of visiting those k buckets so it
    # approximates the behaviour of having n buckets.
    key = (len(module_specs), k, n)
    if key not in self._ranges:
      per_round = k // n
      # visit_rank maps bucket index -> visiting rank; lower ranks are visited
      # first, and (via stable sort) earlier buckets before later ones.
      if per_round:
        visit_rank = [idx % per_round for idx in range(k)]
      else:
        visit_rank = [0] * k
      # Bucket indices reordered into the order they should be visited.
      ordered_buckets = sorted(range(k), key=lambda b: visit_rank[b])

      # Precompute the (start, end) range of each bucket, already in visiting
      # order, and cache the result.
      width = key[0] / k
      self._ranges[key] = tuple(
          (math.floor(width * b), math.floor(width * (b + 1)))
          for b in ordered_buckets)

    return [
        module_specs[random.randrange(lo, hi)] for lo, hi in self._ranges[key]
    ]
80+
81+
82+
class Corpus:
  """Represents a corpus. Comes along with some utility functions."""

  def __init__(self,
               data_path: str,
               additional_flags: Tuple[str, ...] = (),
               delete_flags: Tuple[str, ...] = ()):
    """Loads a corpus from `data_path`.

    Args:
      data_path: root directory of the corpus.
      additional_flags: flags to append to each module's compile command.
      delete_flags: flags to remove from each module's compile command.
    """
    self._module_specs = _build_modulespecs_from_datapath(
        data_path=data_path,
        additional_flags=additional_flags,
        delete_flags=delete_flags)
    self._root_dir = data_path
    # Sort by size, descending, so samplers can prioritize the largest
    # (presumably slowest-to-compile) modules first.
    self._module_specs.sort(key=lambda m: m.size, reverse=True)

  @classmethod
  def from_module_specs(cls, module_specs: List[ModuleSpec]):
    """Construct a Corpus from module specs. Mostly for testing purposes."""
    cps = cls.__new__(cls)  # Avoid calling __init__
    super(cls, cps).__init__()
    cps._module_specs = list(module_specs)  # Don't mutate the original list.
    cps._module_specs.sort(key=lambda m: m.size, reverse=True)
    # Fix: use the same attribute name `__init__` assigns (`_root_dir`, not
    # `root_dir`), so instances built here have the same attribute set as
    # normally constructed ones.
    cps._root_dir = None
    return cps

  def sample(self,
             k: int,
             sort: bool = False,
             sampler=SamplerBucketRoundRobin()) -> List[ModuleSpec]:
    """Samples `k` module_specs, optionally sorting by size descending.

    Args:
      k: number of module specs to sample; clamped to the corpus size.
      sort: whether to sort the returned sample by size, descending.
      sampler: callable taking (module_specs, k=...) and returning a list of
        sampled module specs.

    Returns:
      A list of at most `k` sampled ModuleSpecs.

    Raises:
      ValueError: if fewer than one module spec would be sampled (k < 1, or
        the corpus is empty).
    """
    # Note: sampler is intentionally defaulted to a mutable object, as the
    # only mutable attribute of SamplerBucketRoundRobin is its range cache.
    k = min(len(self._module_specs), k)
    if k < 1:
      raise ValueError('Attempting to sample <1 module specs from corpus.')
    sampled_specs = sampler(self._module_specs, k=k)
    if sort:
      sampled_specs.sort(key=lambda m: m.size, reverse=True)
    return sampled_specs

  def filter(self, p: re.Pattern):
    """Filters module specs, keeping those which match the provided pattern."""
    self._module_specs = [ms for ms in self._module_specs if p.match(ms.name)]

  def __len__(self):
    # Number of module specs currently in the corpus (after any filtering).
    return len(self._module_specs)
127+
128+
129+
def _build_modulespecs_from_datapath(
36130
data_path: str,
37131
additional_flags: Tuple[str, ...] = (),
38132
delete_flags: Tuple[str, ...] = ()
@@ -65,14 +159,17 @@ def build_modulespecs_from_datapath(
65159
module_specs: List[ModuleSpec] = []
66160

67161
# This takes ~7s for 30k modules
68-
for module_path in module_paths:
162+
for rel_module_path in module_paths:
163+
full_module_path = os.path.join(data_path, rel_module_path)
69164
exec_cmd = _load_and_parse_command(
70-
module_path=os.path.join(data_path, module_path),
165+
module_path=full_module_path,
71166
has_thinlto=has_thinlto,
72167
additional_flags=additional_flags,
73168
delete_flags=delete_flags,
74169
cmd_override=cmd_override)
75-
module_specs.append(ModuleSpec(name=module_path, exec_cmd=tuple(exec_cmd)))
170+
size = os.path.getsize(full_module_path + '.bc')
171+
module_specs.append(
172+
ModuleSpec(name=rel_module_path, exec_cmd=tuple(exec_cmd), size=size))
76173

77174
return module_specs
78175

compiler_opt/rl/corpus_test.py

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# pylint: disable=protected-access
1717
import json
1818
import os
19+
import re
1920

2021
import tensorflow as tf
2122

@@ -137,7 +138,7 @@ def test_get_without_thinlto(self):
137138
tempdir.create_file('2.bc')
138139
tempdir.create_file('2.cmd', content='\0'.join(['-cc1', '-O3']))
139140

140-
ms_list = corpus.build_modulespecs_from_datapath(
141+
ms_list = corpus._build_modulespecs_from_datapath(
141142
tempdir.full_path, additional_flags=('-add',))
142143
self.assertEqual(len(ms_list), 2)
143144
ms1 = ms_list[0]
@@ -165,7 +166,7 @@ def test_get_with_thinlto(self):
165166
tempdir.create_file(
166167
'2.cmd', content='\0'.join(['-cc1', '-fthinlto-index=abc']))
167168

168-
ms_list = corpus.build_modulespecs_from_datapath(
169+
ms_list = corpus._build_modulespecs_from_datapath(
169170
tempdir.full_path,
170171
additional_flags=('-add',),
171172
delete_flags=('-fthinlto-index',))
@@ -201,7 +202,7 @@ def test_get_with_override(self):
201202
tempdir.create_file('2.thinlto.bc')
202203
tempdir.create_file('2.cmd', content='\0'.join(['-fthinlto-index=abc']))
203204

204-
ms_list = corpus.build_modulespecs_from_datapath(
205+
ms_list = corpus._build_modulespecs_from_datapath(
205206
tempdir.full_path,
206207
additional_flags=('-add',),
207208
delete_flags=('-fthinlto-index',))
@@ -220,6 +221,111 @@ def test_get_with_override(self):
220221
'-fthinlto-index=' + tempdir.full_path + '/2.thinlto.bc',
221222
'-mllvm', '-thinlto-assume-merged', '-add'))
222223

224+
def test_size(self):
  # Minimal on-disk corpus: one module ('1') with its .bc and .cmd files.
  corpus_description = {'modules': ['1'], 'has_thinlto': False}
  tempdir = self.create_tempdir()
  tempdir.create_file(
      'corpus_description.json', content=json.dumps(corpus_description))
  bc_file = tempdir.create_file('1.bc')
  tempdir.create_file('1.cmd', content='\0'.join(['-cc1']))
  # The loader should record the on-disk size of the module's .bc file in
  # the resulting ModuleSpec.size field.
  self.assertEqual(
      os.path.getsize(bc_file.full_path),
      corpus._build_modulespecs_from_datapath(
          tempdir.full_path, additional_flags=('-add',))[0].size)
235+
236+
237+
class CorpusTest(tf.test.TestCase):
  """Tests for corpus.Corpus construction, sampling, and filtering."""

  def test_constructor(self):
    # Minimal on-disk corpus: one module ('1') with its .bc and .cmd files.
    corpus_description = {'modules': ['1'], 'has_thinlto': False}
    tempdir = self.create_tempdir()
    tempdir.create_file(
        'corpus_description.json', content=json.dumps(corpus_description))
    tempdir.create_file('1.bc')
    tempdir.create_file('1.cmd', content='\0'.join(['-cc1']))

    cps = corpus.Corpus(tempdir.full_path, additional_flags=('-add',))
    # The Corpus should contain exactly the specs the loader produces.
    self.assertEqual(
        corpus._build_modulespecs_from_datapath(
            tempdir.full_path, additional_flags=('-add',)), cps._module_specs)
    self.assertEqual(len(cps), 1)

  def test_sample(self):
    # Specs deliberately given out of size order to verify sort=True returns
    # the sample ordered by size, descending.
    cps = corpus.Corpus.from_module_specs(module_specs=[
        corpus.ModuleSpec(name='smol', size=1),
        corpus.ModuleSpec(name='middle', size=200),
        corpus.ModuleSpec(name='largest', size=500),
        corpus.ModuleSpec(name='small', size=100)
    ])
    sample = cps.sample(4, sort=True)
    self.assertLen(sample, 4)
    self.assertEqual(sample[0].name, 'largest')
    self.assertEqual(sample[1].name, 'middle')
    self.assertEqual(sample[2].name, 'small')
    self.assertEqual(sample[3].name, 'smol')

  def test_filter(self):
    cps = corpus.Corpus.from_module_specs(module_specs=[
        corpus.ModuleSpec(name='smol', size=1),
        corpus.ModuleSpec(name='largest', size=500),
        corpus.ModuleSpec(name='middle', size=200),
        corpus.ModuleSpec(name='small', size=100)
    ])

    # Pattern requires at least one character before an 'l': keeps 'smol',
    # 'middle', 'small'; drops 'largest' (its only 'l' is the first char).
    cps.filter(re.compile(r'.+l'))
    # k=999 exceeds the corpus size, so this samples everything remaining.
    sample = cps.sample(999, sort=True)
    self.assertLen(sample, 3)
    self.assertEqual(sample[0].name, 'middle')
    self.assertEqual(sample[1].name, 'small')
    self.assertEqual(sample[2].name, 'smol')

  def test_sample_zero(self):
    cps = corpus.Corpus.from_module_specs(
        module_specs=[corpus.ModuleSpec(name='smol')])

    # Requesting fewer than one module spec must raise, not return [].
    self.assertRaises(ValueError, cps.sample, 0)
    self.assertRaises(ValueError, cps.sample, -213213213)

  def test_bucket_sample(self):
    # 100 modules with distinct sizes 0..99; the corpus sorts them descending.
    cps = corpus.Corpus.from_module_specs(
        module_specs=[corpus.ModuleSpec(name='', size=i) for i in range(100)])
    # Odds of passing once by pure luck with random.sample: 1.779e-07
    # Try 32 times, for good measure.
    for i in range(32):
      sample = cps.sample(
          k=20, sampler=corpus.SamplerBucketRoundRobin(), sort=True)
      self.assertLen(sample, 20)
      for idx, s in enumerate(sample):
        # Each bucket should be size 5, since n=20 in the sampler, so the
        # idx-th largest sampled spec must come from the idx-th bucket.
        self.assertEqual(s.size // 5, 19 - idx)

  def test_bucket_sample_all(self):
    # Make sure we can sample everything, even if it's not divisible by the
    # `n` in SamplerBucketRoundRobin.
    # Create corpus with a prime number of modules.
    cps = corpus.Corpus.from_module_specs(
        module_specs=[corpus.ModuleSpec(name='', size=i) for i in range(101)])

    # Try 32 times, for good measure.
    for i in range(32):
      sample = cps.sample(
          k=101, sampler=corpus.SamplerBucketRoundRobin(), sort=True)
      self.assertLen(sample, 101)
      for idx, s in enumerate(sample):
        # Since everything is sampled, it should be in perfect order.
        self.assertEqual(s.size, 100 - idx)

  def test_bucket_sample_small(self):
    # Make sure we can sample even when k < n.
    cps = corpus.Corpus.from_module_specs(
        module_specs=[corpus.ModuleSpec(name='', size=i) for i in range(100)])

    # Try all 19 possible values 0 < i < n
    for i in range(1, 20):
      sample = cps.sample(
          k=i, sampler=corpus.SamplerBucketRoundRobin(), sort=True)
      self.assertLen(sample, i)
328+
223329

224330
if __name__ == '__main__':
225331
tf.test.main()

compiler_opt/rl/local_data_collector.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import concurrent.futures
1818
import itertools
19-
import random
2019
import time
2120
from typing import Callable, Dict, Iterator, List, Tuple, Optional
2221

@@ -34,7 +33,7 @@ class LocalDataCollector(data_collector.DataCollector):
3433

3534
def __init__(
3635
self,
37-
module_specs: List[corpus.ModuleSpec],
36+
cps: corpus.Corpus,
3837
num_modules: int,
3938
worker_pool: List[compilation_runner.CompilationRunnerStub],
4039
parser: Callable[[List[str]], Iterator[trajectory.Trajectory]],
@@ -44,7 +43,7 @@ def __init__(
4443
# TODO(mtrofin): type exit_checker_ctor when we get typing.Protocol support
4544
super().__init__()
4645

47-
self._module_specs = module_specs
46+
self._corpus = cps
4847
self._num_modules = num_modules
4948
self._parser = parser
5049
self._worker_pool = worker_pool
@@ -86,7 +85,7 @@ def _schedule_jobs(
8685
jobs = [(module_spec, policy_path, self._reward_stat_map[module_spec.name])
8786
for module_spec in sampled_modules]
8887

89-
# Naive load balancing.
88+
# TODO: Issue #91. Naive load balancing.
9089
ret = []
9190
for i in range(len(jobs)):
9291
ret.append(self._worker_pool[i % len(self._worker_pool)].collect_data(
@@ -108,7 +107,7 @@ def collect_data(
108107
They will be reported using `tf.scalar.summary` by the trainer so these
109108
information is viewable in TensorBoard.
110109
"""
111-
sampled_modules = random.sample(self._module_specs, k=self._num_modules)
110+
sampled_modules = self._corpus.sample(k=self._num_modules, sort=False)
112111
results = self._schedule_jobs(policy_path, sampled_modules)
113112

114113
def wait_for_termination():

compiler_opt/rl/local_data_collector_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ def _test_iterator_fn(data_list):
116116

117117
with LocalWorkerPool(worker_class=MyRunner, count=4) as lwp:
118118
collector = local_data_collector.LocalDataCollector(
119-
module_specs=[corpus.ModuleSpec(name='dummy')] * 100,
119+
cps=corpus.Corpus.from_module_specs(
120+
module_specs=[corpus.ModuleSpec(name='dummy')] * 100),
120121
num_modules=9,
121122
worker_pool=lwp,
122123
parser=create_test_iterator_fn(),
@@ -177,7 +178,8 @@ def wait(self, _):
177178

178179
with LocalWorkerPool(worker_class=Sleeper, count=4) as lwp:
179180
collector = local_data_collector.LocalDataCollector(
180-
module_specs=[corpus.ModuleSpec(name='dummy')] * 200,
181+
cps=corpus.Corpus.from_module_specs(
182+
module_specs=[corpus.ModuleSpec(name='dummy')] * 200),
181183
num_modules=4,
182184
worker_pool=lwp,
183185
parser=parser,

compiler_opt/rl/train_locally.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,8 @@ def train_eval(agent_name=constant.AgentName.PPO,
9999
saver = policy_saver.PolicySaver(policy_dict=policy_dict)
100100

101101
logging.info('Loading module specs from corpus at %s.', FLAGS.data_path)
102-
module_specs = corpus.build_modulespecs_from_datapath(
103-
FLAGS.data_path, problem_config.flags_to_add(),
104-
problem_config.flags_to_delete())
102+
cps = corpus.Corpus(FLAGS.data_path, problem_config.flags_to_add(),
103+
problem_config.flags_to_delete())
105104
logging.info('Done loading module specs from corpus.')
106105

107106
dataset_fn = data_reader.create_sequence_example_dataset_fn(
@@ -136,7 +135,7 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
136135
count=FLAGS.num_workers,
137136
moving_average_decay_rate=moving_average_decay_rate) as worker_pool:
138137
data_collector = local_data_collector.LocalDataCollector(
139-
module_specs=module_specs,
138+
cps=cps,
140139
num_modules=num_modules,
141140
worker_pool=worker_pool,
142141
parser=sequence_example_iterator_fn,

0 commit comments

Comments
 (0)