Skip to content

Commit a777737

Browse files
committed
Switch to tuples
1 parent 99a47d4 commit a777737

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

compiler_opt/rl/corpus.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from absl import logging
2121
from dataclasses import dataclass
22-
from typing import List, Dict, Tuple, Any
22+
from typing import Iterable, List, Dict, Tuple, Any
2323

2424
import json
2525
import os
@@ -45,7 +45,7 @@ def __init__(self):
4545
self._ranges = {}
4646

4747
def __call__(self,
48-
module_specs: List[ModuleSpec],
48+
module_specs: Tuple[ModuleSpec],
4949
k: int,
5050
n: int = 20) -> List[ModuleSpec]:
5151
"""
@@ -86,20 +86,23 @@ def __init__(self,
8686
data_path: str,
8787
additional_flags: Tuple[str, ...] = (),
8888
delete_flags: Tuple[str, ...] = ()):
89-
self._module_specs = _build_modulespecs_from_datapath(
90-
data_path=data_path,
91-
additional_flags=additional_flags,
92-
delete_flags=delete_flags)
89+
self._module_specs = tuple(
90+
sorted(
91+
_build_modulespecs_from_datapath(
92+
data_path=data_path,
93+
additional_flags=additional_flags,
94+
delete_flags=delete_flags),
95+
key=lambda m: m.size,
96+
reverse=True))
9397
self._root_dir = data_path
94-
self._module_specs.sort(key=lambda m: m.size, reverse=True)
9598

9699
@classmethod
97-
def from_module_specs(cls, module_specs: List[ModuleSpec]):
100+
def from_module_specs(cls, module_specs: Iterable[ModuleSpec]):
98101
"""Construct a Corpus from module specs. Mostly for testing purposes."""
99102
cps = cls.__new__(cls) # Avoid calling __init__
100103
super(cls, cps).__init__()
101-
cps._module_specs = list(module_specs) # Don't mutate the original list.
102-
cps._module_specs.sort(key=lambda m: m.size, reverse=True)
104+
cps._module_specs = tuple(
105+
sorted(module_specs, key=lambda m: m.size, reverse=True))
103106
cps.root_dir = None
104107
return cps
105108

@@ -122,8 +125,8 @@ def filter(self, p: re.Pattern):
122125
"""Filters module specs, keeping those which match the provided pattern."""
123126
self._module_specs = [ms for ms in self._module_specs if p.match(ms.name)]
124127

125-
def get_modules_copy(self):
126-
return list(self._module_specs)
128+
def get_modules(self):
129+
return self._module_specs
127130

128131
def __len__(self):
129132
return len(self._module_specs)

0 commit comments

Comments
 (0)