19
19
20
20
from absl import logging
21
21
from dataclasses import dataclass
22
- from typing import List , Dict , Tuple , Any
22
+ from typing import Iterable , List , Dict , Tuple , Any
23
23
24
24
import json
25
25
import os
@@ -45,7 +45,7 @@ def __init__(self):
45
45
self ._ranges = {}
46
46
47
47
def __call__ (self ,
48
- module_specs : List [ModuleSpec ],
48
+ module_specs : Tuple [ModuleSpec ],
49
49
k : int ,
50
50
n : int = 20 ) -> List [ModuleSpec ]:
51
51
"""
@@ -86,20 +86,23 @@ def __init__(self,
86
86
data_path : str ,
87
87
additional_flags : Tuple [str , ...] = (),
88
88
delete_flags : Tuple [str , ...] = ()):
89
- self ._module_specs = _build_modulespecs_from_datapath (
90
- data_path = data_path ,
91
- additional_flags = additional_flags ,
92
- delete_flags = delete_flags )
89
+ self .module_specs = tuple (
90
+ sorted (
91
+ _build_modulespecs_from_datapath (
92
+ data_path = data_path ,
93
+ additional_flags = additional_flags ,
94
+ delete_flags = delete_flags ),
95
+ key = lambda m : m .size ,
96
+ reverse = True ))
93
97
self ._root_dir = data_path
94
- self ._module_specs .sort (key = lambda m : m .size , reverse = True )
95
98
96
99
@classmethod
97
- def from_module_specs (cls , module_specs : List [ModuleSpec ]):
100
+ def from_module_specs (cls , module_specs : Iterable [ModuleSpec ]):
98
101
"""Construct a Corpus from module specs. Mostly for testing purposes."""
99
102
cps = cls .__new__ (cls ) # Avoid calling __init__
100
103
super (cls , cps ).__init__ ()
101
- cps ._module_specs = list ( module_specs ) # Don't mutate the original list.
102
- cps . _module_specs . sort ( key = lambda m : m .size , reverse = True )
104
+ cps .module_specs = tuple (
105
+ sorted ( module_specs , key = lambda m : m .size , reverse = True ) )
103
106
cps .root_dir = None
104
107
return cps
105
108
@@ -110,23 +113,20 @@ def sample(self,
110
113
"""Samples `k` module_specs, optionally sorting by size descending."""
111
114
# Note: sampler is intentionally defaulted to a mutable object, as the
112
115
# only mutable attribute of SamplerBucketRoundRobin is its range cache.
113
- k = min (len (self ._module_specs ), k )
116
+ k = min (len (self .module_specs ), k )
114
117
if k < 1 :
115
118
raise ValueError ('Attempting to sample <1 module specs from corpus.' )
116
- sampled_specs = sampler (self ._module_specs , k = k )
119
+ sampled_specs = sampler (self .module_specs , k = k )
117
120
if sort :
118
121
sampled_specs .sort (key = lambda m : m .size , reverse = True )
119
122
return sampled_specs
120
123
121
124
def filter (self , p : re .Pattern ):
122
125
"""Filters module specs, keeping those which match the provided pattern."""
123
- self ._module_specs = [ms for ms in self ._module_specs if p .match (ms .name )]
124
-
125
- def get_modules_copy (self ):
126
- return list (self ._module_specs )
126
+ self .module_specs = [ms for ms in self .module_specs if p .match (ms .name )]
127
127
128
128
def __len__ (self ):
129
- return len (self ._module_specs )
129
+ return len (self .module_specs )
130
130
131
131
132
132
def _build_modulespecs_from_datapath (
0 commit comments