diff --git a/compiler_opt/rl/corpus.py b/compiler_opt/rl/corpus.py index e29abf31..f30c1b6b 100644 --- a/compiler_opt/rl/corpus.py +++ b/compiler_opt/rl/corpus.py @@ -19,7 +19,7 @@ from absl import logging from dataclasses import dataclass -from typing import List, Dict, Tuple, Any +from typing import Iterable, List, Dict, Tuple, Any import json import os @@ -45,7 +45,7 @@ def __init__(self): self._ranges = {} def __call__(self, - module_specs: List[ModuleSpec], + module_specs: Tuple[ModuleSpec], k: int, n: int = 20) -> List[ModuleSpec]: """ @@ -86,20 +86,23 @@ def __init__(self, data_path: str, additional_flags: Tuple[str, ...] = (), delete_flags: Tuple[str, ...] = ()): - self._module_specs = _build_modulespecs_from_datapath( - data_path=data_path, - additional_flags=additional_flags, - delete_flags=delete_flags) + self._module_specs = tuple( + sorted( + _build_modulespecs_from_datapath( + data_path=data_path, + additional_flags=additional_flags, + delete_flags=delete_flags), + key=lambda m: m.size, + reverse=True)) self._root_dir = data_path - self._module_specs.sort(key=lambda m: m.size, reverse=True) @classmethod - def from_module_specs(cls, module_specs: List[ModuleSpec]): + def from_module_specs(cls, module_specs: Iterable[ModuleSpec]): """Construct a Corpus from module specs. Mostly for testing purposes.""" cps = cls.__new__(cls) # Avoid calling __init__ super(cls, cps).__init__() - cps._module_specs = list(module_specs) # Don't mutate the original list. - cps._module_specs.sort(key=lambda m: m.size, reverse=True) + cps._module_specs = tuple( + sorted(module_specs, key=lambda m: m.size, reverse=True)) cps.root_dir = None return cps @@ -120,7 +123,12 @@ def sample(self, def filter(self, p: re.Pattern): """Filters module specs, keeping those which match the provided pattern.""" - self._module_specs = [ms for ms in self._module_specs if p.match(ms.name)] + self._module_specs = tuple( + ms for ms in self._module_specs if p.match(ms.name)) + + @property + def module_specs(self): + return self._module_specs def __len__(self): return len(self._module_specs) diff --git a/compiler_opt/rl/corpus_test.py b/compiler_opt/rl/corpus_test.py index c8202206..5e6d153e 100644 --- a/compiler_opt/rl/corpus_test.py +++ b/compiler_opt/rl/corpus_test.py @@ -246,8 +246,10 @@ def test_constructor(self): cps = corpus.Corpus(tempdir.full_path, additional_flags=('-add',)) self.assertEqual( - corpus._build_modulespecs_from_datapath( - tempdir.full_path, additional_flags=('-add',)), cps._module_specs) + tuple( + corpus._build_modulespecs_from_datapath( + tempdir.full_path, additional_flags=('-add',))), + cps.module_specs) self.assertEqual(len(cps), 1) def test_sample(self):