Skip to content

Commit 60e0c59

Browse files
authored
Add SamplerWithoutReplacement to corpus.py. (#259)
* Add SamplerWithoutReplacement to corpus.py. * Fix function signature for type annotation. * Yapf. * Fix local_data_collector_test.py * Fix construction of sampler in local_data_collector_test.py. * Fix formatting. * Yapf. * Re-add multiprocessing include.
1 parent 52fa04f commit 60e0c59

File tree

3 files changed

+103
-31
lines changed

3 files changed

+103
-31
lines changed

compiler_opt/rl/corpus.py

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from absl import logging
2222
from dataclasses import dataclass
23-
from typing import Any, Callable, Dict, List, Optional, Tuple
23+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
2424

2525
import json
2626
import os
@@ -126,13 +126,17 @@ class Sampler(metaclass=abc.ABCMeta):
126126
"""Corpus sampler abstraction."""
127127

128128
@abc.abstractmethod
129-
def __call__(self,
130-
module_specs: Tuple[ModuleSpec],
131-
k: int,
132-
n: int = 20) -> List[ModuleSpec]:
129+
def __init__(self, module_specs: Tuple[ModuleSpec]):
130+
self._module_specs = module_specs
131+
132+
@abc.abstractmethod
133+
def reset(self):
134+
pass
135+
136+
@abc.abstractmethod
137+
def __call__(self, k: int, n: int = 20) -> List[ModuleSpec]:
133138
"""
134139
Args:
135-
module_specs: list of module_specs to sample from
136140
k: number of modules to sample
137141
n: number of buckets to use
138142
"""
@@ -144,13 +148,14 @@ class SamplerBucketRoundRobin(Sampler):
144148
round-robin order. The buckets are sequential sections of module_specs of
145149
roughly equal lengths."""
146150

147-
def __init__(self):
151+
def __init__(self, module_specs: Tuple[ModuleSpec]):
148152
self._ranges = {}
153+
super().__init__(module_specs)
149154

150-
def __call__(self,
151-
module_specs: Tuple[ModuleSpec],
152-
k: int,
153-
n: int = 20) -> List[ModuleSpec]:
155+
def reset(self):
156+
pass
157+
158+
def __call__(self, k: int, n: int = 20) -> List[ModuleSpec]:
154159
"""
155160
Args:
156161
module_specs: list of module_specs to sample from
@@ -161,7 +166,7 @@ def __call__(self,
161166
# Essentially, split module_specs into k buckets, then define the order of
162167
# visiting the k buckets such that it approximates the behaviour of having
163168
# n buckets.
164-
specs_len = len(module_specs)
169+
specs_len = len(self._module_specs)
165170
if (specs_len, k, n) not in self._ranges:
166171
quotient = k // n
167172
# rev_map maps from bucket # (implicitly via index) to order of visiting.
@@ -177,11 +182,48 @@ def __call__(self,
177182
math.floor(bucket_size_float * (i + 1))) for i in mapping)
178183

179184
return [
180-
module_specs[random.randrange(start, end)]
185+
self._module_specs[random.randrange(start, end)]
181186
for start, end in self._ranges[(specs_len, k, n)]
182187
]
183188

184189

190+
class CorpusExhaustedError(Exception):
191+
pass
192+
193+
194+
class SamplerWithoutReplacement(Sampler):
195+
"""Randomly samples the corpus, without replacement."""
196+
197+
def __init__(self, module_specs: Tuple[ModuleSpec]):
198+
super().__init__(module_specs)
199+
self._idx = 0
200+
self._shuffle_order()
201+
202+
def _shuffle_order(self):
203+
self._module_specs = tuple(
204+
random.sample(self._module_specs, len(self._module_specs)))
205+
206+
def reset(self):
207+
self._shuffle_order()
208+
self._idx = 0
209+
210+
def __call__(self, k: int, n: int = 10) -> List[ModuleSpec]:
211+
"""
212+
Args:
213+
k: number of modules to sample
214+
n: ignored
215+
Raises:
216+
CorpusExhaustedError if there are fewer than k elements left to sample in
217+
the corpus.
218+
"""
219+
endpoint = self._idx + k
220+
if endpoint > len(self._module_specs):
221+
raise CorpusExhaustedError()
222+
results = self._module_specs[self._idx:endpoint]
223+
self._idx = self._idx + k
224+
return list(results)
225+
226+
185227
class Corpus:
186228
"""Represents a corpus.
187229
@@ -230,7 +272,7 @@ def __init__(self,
230272
additional_flags: Tuple[str, ...] = (),
231273
delete_flags: Tuple[str, ...] = (),
232274
replace_flags: Optional[Dict[str, str]] = None,
233-
sampler: Sampler = SamplerBucketRoundRobin()):
275+
sampler_type: Type[Sampler] = SamplerBucketRoundRobin):
234276
"""
235277
Prepares the corpus by pre-loading all the CorpusElements and preparing for
236278
sampling. Command line origin (.cmd file or override) is decided, and final
@@ -252,7 +294,6 @@ def __init__(self,
252294
matching it. None to include everything.
253295
"""
254296
self._base_dir = data_path
255-
self._sampler = sampler
256297
# TODO: (b/233935329) Per-corpus *fdo profile paths can be read into
257298
# {additional|delete}_flags here
258299
with tf.io.gfile.GFile(
@@ -337,6 +378,10 @@ def get_cmdline(name: str):
337378
has_thinlto=has_thinlto), module_paths)
338379
self._module_specs = tuple(
339380
sorted(contents, key=lambda m: m.size, reverse=True))
381+
self._sampler = sampler_type(self._module_specs)
382+
383+
def reset(self):
384+
self._sampler.reset()
340385

341386
def sample(self, k: int, sort: bool = False) -> List[ModuleSpec]:
342387
"""Samples `k` module_specs, optionally sorting by size descending.
@@ -349,7 +394,7 @@ def sample(self, k: int, sort: bool = False) -> List[ModuleSpec]:
349394
k = min(len(self._module_specs), k)
350395
if k < 1:
351396
raise ValueError('Attempting to sample <1 module specs from corpus.')
352-
sampled_specs = self._sampler(self._module_specs, k=k)
397+
sampled_specs = self._sampler(k=k)
353398
if sort:
354399
sampled_specs.sort(key=lambda m: m.size, reverse=True)
355400
return sampled_specs

compiler_opt/rl/corpus_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,36 @@ def test_sample(self):
237237
self.assertEqual(sample[2].name, 'small')
238238
self.assertEqual(sample[3].name, 'smol')
239239

240+
def test_sample_without_replacement(self):
241+
cps = corpus.create_corpus_for_testing(
242+
location=self.create_tempdir(),
243+
elements=[
244+
corpus.ModuleSpec(name='smol', size=1),
245+
corpus.ModuleSpec(name='middle', size=200),
246+
corpus.ModuleSpec(name='largest', size=500),
247+
corpus.ModuleSpec(name='small', size=100)
248+
],
249+
sampler_type=corpus.SamplerWithoutReplacement)
250+
samples = []
251+
samples.extend(cps.sample(1, sort=True))
252+
self.assertLen(samples, 1)
253+
samples.extend(cps.sample(1, sort=True))
254+
self.assertLen(samples, 2)
255+
# Can't sample 3 from the corpus because there are only 2 elements left
256+
with self.assertRaises(corpus.CorpusExhaustedError):
257+
samples.extend(cps.sample(3, sort=True))
258+
# But, we can sample exactly 2 more
259+
self.assertLen(samples, 2)
260+
samples.extend(cps.sample(2, sort=True))
261+
self.assertLen(samples, 4)
262+
with self.assertRaises(corpus.CorpusExhaustedError):
263+
samples.extend(cps.sample(1, sort=True))
264+
samples.sort(key=lambda m: m.size, reverse=True)
265+
self.assertEqual(samples[0].name, 'largest')
266+
self.assertEqual(samples[1].name, 'middle')
267+
self.assertEqual(samples[2].name, 'small')
268+
self.assertEqual(samples[3].name, 'smol')
269+
240270
def test_filter(self):
241271
cps = corpus.create_corpus_for_testing(
242272
location=self.create_tempdir(),

compiler_opt/rl/local_data_collector_test.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,14 @@ def collect_data(self, *args, **kwargs):
116116
class DeterministicSampler(corpus.Sampler):
117117
"""A corpus sampler that returns modules in order, and can also be reset."""
118118

119-
def __init__(self):
119+
def __init__(self, module_specs: Tuple[corpus.ModuleSpec]):
120+
super().__init__(module_specs)
120121
self._cur_pos = 0
121122

122-
def __call__(self,
123-
module_specs: Tuple[corpus.ModuleSpec],
124-
k: int,
125-
n: int = 20) -> List[corpus.ModuleSpec]:
123+
def __call__(self, k: int, n: int = 20) -> List[corpus.ModuleSpec]:
126124
ret = []
127125
for _ in range(k):
128-
ret.append(module_specs[self._cur_pos % len(module_specs)])
126+
ret.append(self._module_specs[self._cur_pos % len(self._module_specs)])
129127
self._cur_pos += 1
130128
return ret
131129

@@ -152,16 +150,15 @@ def _test_iterator_fn(data_list):
152150

153151
return _test_iterator_fn
154152

155-
sampler = DeterministicSampler()
156153
with LocalWorkerPoolManager(worker_class=MyRunner, count=4) as lwp:
154+
cps = corpus.create_corpus_for_testing(
155+
location=self.create_tempdir(),
156+
elements=[
157+
corpus.ModuleSpec(name=f'dummy{i}', size=i) for i in range(100)
158+
],
159+
sampler_type=DeterministicSampler)
157160
collector = local_data_collector.LocalDataCollector(
158-
cps=corpus.create_corpus_for_testing(
159-
location=self.create_tempdir(),
160-
elements=[
161-
corpus.ModuleSpec(name=f'dummy{i}', size=i)
162-
for i in range(100)
163-
],
164-
sampler=sampler),
161+
cps=cps,
165162
num_modules=9,
166163
worker_pool=lwp,
167164
parser=create_test_iterator_fn(),
@@ -171,7 +168,7 @@ def _test_iterator_fn(data_list):
171168
# reset the sampler, so the next time we collect, we collect the same
172169
# modules. We do it before the collect_data call, because that's when
173170
# we'll re-sample to prefetch the next batch.
174-
sampler.reset()
171+
cps.reset()
175172

176173
data_iterator, monitor_dict = collector.collect_data(
177174
policy=_mock_policy, model_id=0)

0 commit comments

Comments
 (0)