16
16
# pylint: disable=protected-access
17
17
import json
18
18
import os
19
+ import re
19
20
20
21
import tensorflow as tf
21
22
@@ -137,7 +138,7 @@ def test_get_without_thinlto(self):
137
138
tempdir .create_file ('2.bc' )
138
139
tempdir .create_file ('2.cmd' , content = '\0 ' .join (['-cc1' , '-O3' ]))
139
140
140
- ms_list = corpus .build_modulespecs_from_datapath (
141
+ ms_list = corpus ._build_modulespecs_from_datapath (
141
142
tempdir .full_path , additional_flags = ('-add' ,))
142
143
self .assertEqual (len (ms_list ), 2 )
143
144
ms1 = ms_list [0 ]
@@ -165,7 +166,7 @@ def test_get_with_thinlto(self):
165
166
tempdir .create_file (
166
167
'2.cmd' , content = '\0 ' .join (['-cc1' , '-fthinlto-index=abc' ]))
167
168
168
- ms_list = corpus .build_modulespecs_from_datapath (
169
+ ms_list = corpus ._build_modulespecs_from_datapath (
169
170
tempdir .full_path ,
170
171
additional_flags = ('-add' ,),
171
172
delete_flags = ('-fthinlto-index' ,))
@@ -201,7 +202,7 @@ def test_get_with_override(self):
201
202
tempdir .create_file ('2.thinlto.bc' )
202
203
tempdir .create_file ('2.cmd' , content = '\0 ' .join (['-fthinlto-index=abc' ]))
203
204
204
- ms_list = corpus .build_modulespecs_from_datapath (
205
+ ms_list = corpus ._build_modulespecs_from_datapath (
205
206
tempdir .full_path ,
206
207
additional_flags = ('-add' ,),
207
208
delete_flags = ('-fthinlto-index' ,))
@@ -220,6 +221,111 @@ def test_get_with_override(self):
220
221
'-fthinlto-index=' + tempdir .full_path + '/2.thinlto.bc' ,
221
222
'-mllvm' , '-thinlto-assume-merged' , '-add' ))
222
223
224
+ def test_size (self ):
225
+ corpus_description = {'modules' : ['1' ], 'has_thinlto' : False }
226
+ tempdir = self .create_tempdir ()
227
+ tempdir .create_file (
228
+ 'corpus_description.json' , content = json .dumps (corpus_description ))
229
+ bc_file = tempdir .create_file ('1.bc' )
230
+ tempdir .create_file ('1.cmd' , content = '\0 ' .join (['-cc1' ]))
231
+ self .assertEqual (
232
+ os .path .getsize (bc_file .full_path ),
233
+ corpus ._build_modulespecs_from_datapath (
234
+ tempdir .full_path , additional_flags = ('-add' ,))[0 ].size )
235
+
236
+
237
+ class CorpusTest (tf .test .TestCase ):
238
+
239
+ def test_constructor (self ):
240
+ corpus_description = {'modules' : ['1' ], 'has_thinlto' : False }
241
+ tempdir = self .create_tempdir ()
242
+ tempdir .create_file (
243
+ 'corpus_description.json' , content = json .dumps (corpus_description ))
244
+ tempdir .create_file ('1.bc' )
245
+ tempdir .create_file ('1.cmd' , content = '\0 ' .join (['-cc1' ]))
246
+
247
+ cps = corpus .Corpus (tempdir .full_path , additional_flags = ('-add' ,))
248
+ self .assertEqual (
249
+ corpus ._build_modulespecs_from_datapath (
250
+ tempdir .full_path , additional_flags = ('-add' ,)), cps ._module_specs )
251
+ self .assertEqual (len (cps ), 1 )
252
+
253
+ def test_sample (self ):
254
+ cps = corpus .Corpus .from_module_specs (module_specs = [
255
+ corpus .ModuleSpec (name = 'smol' , size = 1 ),
256
+ corpus .ModuleSpec (name = 'middle' , size = 200 ),
257
+ corpus .ModuleSpec (name = 'largest' , size = 500 ),
258
+ corpus .ModuleSpec (name = 'small' , size = 100 )
259
+ ])
260
+ sample = cps .sample (4 , sort = True )
261
+ self .assertLen (sample , 4 )
262
+ self .assertEqual (sample [0 ].name , 'largest' )
263
+ self .assertEqual (sample [1 ].name , 'middle' )
264
+ self .assertEqual (sample [2 ].name , 'small' )
265
+ self .assertEqual (sample [3 ].name , 'smol' )
266
+
267
+ def test_filter (self ):
268
+ cps = corpus .Corpus .from_module_specs (module_specs = [
269
+ corpus .ModuleSpec (name = 'smol' , size = 1 ),
270
+ corpus .ModuleSpec (name = 'largest' , size = 500 ),
271
+ corpus .ModuleSpec (name = 'middle' , size = 200 ),
272
+ corpus .ModuleSpec (name = 'small' , size = 100 )
273
+ ])
274
+
275
+ cps .filter (re .compile (r'.+l' ))
276
+ sample = cps .sample (999 , sort = True )
277
+ self .assertLen (sample , 3 )
278
+ self .assertEqual (sample [0 ].name , 'middle' )
279
+ self .assertEqual (sample [1 ].name , 'small' )
280
+ self .assertEqual (sample [2 ].name , 'smol' )
281
+
282
+ def test_sample_zero (self ):
283
+ cps = corpus .Corpus .from_module_specs (
284
+ module_specs = [corpus .ModuleSpec (name = 'smol' )])
285
+
286
+ self .assertRaises (ValueError , cps .sample , 0 )
287
+ self .assertRaises (ValueError , cps .sample , - 213213213 )
288
+
289
+ def test_bucket_sample (self ):
290
+ cps = corpus .Corpus .from_module_specs (
291
+ module_specs = [corpus .ModuleSpec (name = '' , size = i ) for i in range (100 )])
292
+ # Odds of passing once by pure luck with random.sample: 1.779e-07
293
+ # Try 32 times, for good measure.
294
+ for i in range (32 ):
295
+ sample = cps .sample (
296
+ k = 20 , sampler = corpus .SamplerBucketRoundRobin (), sort = True )
297
+ self .assertLen (sample , 20 )
298
+ for idx , s in enumerate (sample ):
299
+ # Each bucket should be size 5, since n=20 in the sampler
300
+ self .assertEqual (s .size // 5 , 19 - idx )
301
+
302
+ def test_bucket_sample_all (self ):
303
+ # Make sure we can sample everything, even if it's not divisible by the
304
+ # `n` in SamplerBucketRoundRobin.
305
+ # Create corpus with a prime number of modules.
306
+ cps = corpus .Corpus .from_module_specs (
307
+ module_specs = [corpus .ModuleSpec (name = '' , size = i ) for i in range (101 )])
308
+
309
+ # Try 32 times, for good measure.
310
+ for i in range (32 ):
311
+ sample = cps .sample (
312
+ k = 101 , sampler = corpus .SamplerBucketRoundRobin (), sort = True )
313
+ self .assertLen (sample , 101 )
314
+ for idx , s in enumerate (sample ):
315
+ # Since everything is sampled, it should be in perfect order.
316
+ self .assertEqual (s .size , 100 - idx )
317
+
318
+ def test_bucket_sample_small (self ):
319
+ # Make sure we can sample even when k < n.
320
+ cps = corpus .Corpus .from_module_specs (
321
+ module_specs = [corpus .ModuleSpec (name = '' , size = i ) for i in range (100 )])
322
+
323
+ # Try all 19 possible values 0 < i < n
324
+ for i in range (1 , 20 ):
325
+ sample = cps .sample (
326
+ k = i , sampler = corpus .SamplerBucketRoundRobin (), sort = True )
327
+ self .assertLen (sample , i )
328
+
223
329
224
330
if __name__ == '__main__' :
225
331
tf .test .main ()
0 commit comments