20
20
21
21
from absl import logging
22
22
from dataclasses import dataclass
23
- from typing import Any , Callable , Dict , List , Optional , Tuple
23
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Type
24
24
25
25
import json
26
26
import os
@@ -126,13 +126,17 @@ class Sampler(metaclass=abc.ABCMeta):
126
126
"""Corpus sampler abstraction."""
127
127
128
128
@abc .abstractmethod
129
- def __call__ (self ,
130
- module_specs : Tuple [ModuleSpec ],
131
- k : int ,
132
- n : int = 20 ) -> List [ModuleSpec ]:
129
+ def __init__ (self , module_specs : Tuple [ModuleSpec ]):
130
+ self ._module_specs = module_specs
131
+
132
+ @abc .abstractmethod
133
+ def reset (self ):
134
+ pass
135
+
136
+ @abc .abstractmethod
137
+ def __call__ (self , k : int , n : int = 20 ) -> List [ModuleSpec ]:
133
138
"""
134
139
Args:
135
- module_specs: list of module_specs to sample from
136
140
k: number of modules to sample
137
141
n: number of buckets to use
138
142
"""
@@ -144,13 +148,14 @@ class SamplerBucketRoundRobin(Sampler):
144
148
round-robin order. The buckets are sequential sections of module_specs of
145
149
roughly equal lengths."""
146
150
147
- def __init__ (self ):
151
+ def __init__ (self , module_specs : Tuple [ ModuleSpec ] ):
148
152
self ._ranges = {}
153
+ super ().__init__ (module_specs )
149
154
150
- def __call__ (self ,
151
- module_specs : Tuple [ ModuleSpec ],
152
- k : int ,
153
- n : int = 20 ) -> List [ModuleSpec ]:
155
+ def reset (self ):
156
+ pass
157
+
158
+ def __call__ ( self , k : int , n : int = 20 ) -> List [ModuleSpec ]:
154
159
"""
155
160
Args:
156
161
module_specs: list of module_specs to sample from
@@ -161,7 +166,7 @@ def __call__(self,
161
166
# Essentially, split module_specs into k buckets, then define the order of
162
167
# visiting the k buckets such that it approximates the behaviour of having
163
168
# n buckets.
164
- specs_len = len (module_specs )
169
+ specs_len = len (self . _module_specs )
165
170
if (specs_len , k , n ) not in self ._ranges :
166
171
quotient = k // n
167
172
# rev_map maps from bucket # (implicitly via index) to order of visiting.
@@ -177,11 +182,48 @@ def __call__(self,
177
182
math .floor (bucket_size_float * (i + 1 ))) for i in mapping )
178
183
179
184
return [
180
- module_specs [random .randrange (start , end )]
185
+ self . _module_specs [random .randrange (start , end )]
181
186
for start , end in self ._ranges [(specs_len , k , n )]
182
187
]
183
188
184
189
190
class CorpusExhaustedError(Exception):
  """Signals that a sampler has fewer unsampled elements left than requested."""
192
+
193
+
194
class SamplerWithoutReplacement(Sampler):
  """Randomly samples the corpus, without replacement.

  The module specs are shuffled when the sampler is constructed. Each call
  then hands out the next `k` entries of that shuffled order, so no spec is
  returned twice until `reset()` reshuffles and rewinds the sampler.
  """

  def __init__(self, module_specs: Tuple[ModuleSpec]):
    """
    Args:
      module_specs: the module specs to sample from
    """
    super().__init__(module_specs)
    # Position of the first not-yet-returned element in the shuffled order.
    self._idx = 0
    self._shuffle_order()

  def _shuffle_order(self):
    """Replaces the stored specs with a freshly shuffled tuple of them."""
    self._module_specs = tuple(
        random.sample(self._module_specs, len(self._module_specs)))

  def reset(self):
    """Reshuffles the specs and rewinds, making every spec available again."""
    self._shuffle_order()
    self._idx = 0

  def __call__(self, k: int, n: int = 10) -> List[ModuleSpec]:
    """
    Args:
      k: number of modules to sample
      n: ignored
    Returns:
      The next k module specs of the current shuffled order, as a new list.
    Raises:
      ValueError if k is negative.
      CorpusExhaustedError if there are fewer than k elements left to sample in
      the corpus.
    """
    # A negative k would move the cursor backwards and slice the wrong range,
    # handing out specs that were already returned.
    if k < 0:
      raise ValueError(f'Cannot sample a negative number of modules: {k}')
    endpoint = self._idx + k
    total = len(self._module_specs)
    if endpoint > total:
      raise CorpusExhaustedError(
          f'Requested {k} module specs but only {total - self._idx} remain.')
    results = self._module_specs[self._idx:endpoint]
    self._idx = endpoint
    return list(results)
225
+
226
+
185
227
class Corpus :
186
228
"""Represents a corpus.
187
229
@@ -230,7 +272,7 @@ def __init__(self,
230
272
additional_flags : Tuple [str , ...] = (),
231
273
delete_flags : Tuple [str , ...] = (),
232
274
replace_flags : Optional [Dict [str , str ]] = None ,
233
- sampler : Sampler = SamplerBucketRoundRobin () ):
275
+ sampler_type : Type [ Sampler ] = SamplerBucketRoundRobin ):
234
276
"""
235
277
Prepares the corpus by pre-loading all the CorpusElements and preparing for
236
278
sampling. Command line origin (.cmd file or override) is decided, and final
@@ -252,7 +294,6 @@ def __init__(self,
252
294
matching it. None to include everything.
253
295
"""
254
296
self ._base_dir = data_path
255
- self ._sampler = sampler
256
297
# TODO: (b/233935329) Per-corpus *fdo profile paths can be read into
257
298
# {additional|delete}_flags here
258
299
with tf .io .gfile .GFile (
@@ -337,6 +378,10 @@ def get_cmdline(name: str):
337
378
has_thinlto = has_thinlto ), module_paths )
338
379
self ._module_specs = tuple (
339
380
sorted (contents , key = lambda m : m .size , reverse = True ))
381
+ self ._sampler = sampler_type (self ._module_specs )
382
+
383
def reset(self):
  """Resets this corpus' sampler by delegating to the sampler's `reset()`."""
  sampler = self._sampler
  sampler.reset()
340
385
341
386
def sample (self , k : int , sort : bool = False ) -> List [ModuleSpec ]:
342
387
"""Samples `k` module_specs, optionally sorting by size descending.
@@ -349,7 +394,7 @@ def sample(self, k: int, sort: bool = False) -> List[ModuleSpec]:
349
394
k = min (len (self ._module_specs ), k )
350
395
if k < 1 :
351
396
raise ValueError ('Attempting to sample <1 module specs from corpus.' )
352
- sampled_specs = self ._sampler (self . _module_specs , k = k )
397
+ sampled_specs = self ._sampler (k = k )
353
398
if sort :
354
399
sampled_specs .sort (key = lambda m : m .size , reverse = True )
355
400
return sampled_specs
0 commit comments