Skip to content

Commit 7394179

Browse files
authored
Modify Corpus to store module specs as a tuple (#113)
* Add method to get modules in a Corpus * change to get_modules_copy * Switch to tuples * make module_specs private
1 parent 31cfd12 commit 7394179

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

compiler_opt/rl/corpus.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from absl import logging
2121
from dataclasses import dataclass
22-
from typing import List, Dict, Tuple, Any
22+
from typing import Iterable, List, Dict, Tuple, Any
2323

2424
import json
2525
import os
@@ -45,7 +45,7 @@ def __init__(self):
4545
self._ranges = {}
4646

4747
def __call__(self,
48-
module_specs: List[ModuleSpec],
48+
module_specs: Tuple[ModuleSpec],
4949
k: int,
5050
n: int = 20) -> List[ModuleSpec]:
5151
"""
@@ -86,20 +86,23 @@ def __init__(self,
8686
data_path: str,
8787
additional_flags: Tuple[str, ...] = (),
8888
delete_flags: Tuple[str, ...] = ()):
89-
self._module_specs = _build_modulespecs_from_datapath(
90-
data_path=data_path,
91-
additional_flags=additional_flags,
92-
delete_flags=delete_flags)
89+
self._module_specs = tuple(
90+
sorted(
91+
_build_modulespecs_from_datapath(
92+
data_path=data_path,
93+
additional_flags=additional_flags,
94+
delete_flags=delete_flags),
95+
key=lambda m: m.size,
96+
reverse=True))
9397
self._root_dir = data_path
94-
self._module_specs.sort(key=lambda m: m.size, reverse=True)
9598

9699
@classmethod
97-
def from_module_specs(cls, module_specs: List[ModuleSpec]):
100+
def from_module_specs(cls, module_specs: Iterable[ModuleSpec]):
98101
"""Construct a Corpus from module specs. Mostly for testing purposes."""
99102
cps = cls.__new__(cls) # Avoid calling __init__
100103
super(cls, cps).__init__()
101-
cps._module_specs = list(module_specs) # Don't mutate the original list.
102-
cps._module_specs.sort(key=lambda m: m.size, reverse=True)
104+
cps._module_specs = tuple(
105+
sorted(module_specs, key=lambda m: m.size, reverse=True))
103106
cps.root_dir = None
104107
return cps
105108

@@ -120,7 +123,12 @@ def sample(self,
120123

121124
def filter(self, p: re.Pattern):
122125
"""Filters module specs, keeping those which match the provided pattern."""
123-
self._module_specs = [ms for ms in self._module_specs if p.match(ms.name)]
126+
self._module_specs = tuple(
127+
ms for ms in self._module_specs if p.match(ms.name))
128+
129+
@property
130+
def module_specs(self):
131+
return self._module_specs
124132

125133
def __len__(self):
126134
return len(self._module_specs)

compiler_opt/rl/corpus_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,10 @@ def test_constructor(self):
246246

247247
cps = corpus.Corpus(tempdir.full_path, additional_flags=('-add',))
248248
self.assertEqual(
249-
corpus._build_modulespecs_from_datapath(
250-
tempdir.full_path, additional_flags=('-add',)), cps._module_specs)
249+
tuple(
250+
corpus._build_modulespecs_from_datapath(
251+
tempdir.full_path, additional_flags=('-add',))),
252+
cps.module_specs)
251253
self.assertEqual(len(cps), 1)
252254

253255
def test_sample(self):

0 commit comments

Comments
 (0)