Skip to content

Commit 65e1ec3

Browse files
authored
Corpus: let module_filter be a function (#169)
That makes it easier to filter arbitrary modules (e.g. from an allow/deny list)
1 parent 91fbd44 commit 65e1ec3

File tree

3 files changed

+6
-8
lines changed

3 files changed

+6
-8
lines changed

compiler_opt/rl/corpus.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
import concurrent.futures
1818
import math
1919
import random
20-
import re
2120

2221
from absl import logging
2322
from dataclasses import dataclass
24-
from typing import Any, Dict, List, Optional, Tuple
23+
from typing import Any, Callable, Dict, List, Optional, Tuple
2524

2625
import json
2726
import os
@@ -227,7 +226,7 @@ class ReplaceContext:
227226
def __init__(self,
228227
*,
229228
data_path: str,
230-
module_filter: Optional[re.Pattern] = None,
229+
module_filter: Optional[Callable[[str], bool]] = None,
231230
additional_flags: Tuple[str, ...] = (),
232231
delete_flags: Tuple[str, ...] = (),
233232
replace_flags: Optional[Dict[str, str]] = None,
@@ -309,9 +308,7 @@ def __init__(self,
309308
raise ValueError('do not use add/delete flags to replace')
310309

311310
if module_filter:
312-
module_paths = [
313-
name for name in module_paths if module_filter.match(name)
314-
]
311+
module_paths = [name for name in module_paths if module_filter(name)]
315312

316313
def get_cmdline(name: str):
317314
if cmd_override_was_specified:

compiler_opt/rl/corpus_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def test_filter(self):
230230
corpus.ModuleSpec(name='largest', size=500),
231231
corpus.ModuleSpec(name='small', size=100)
232232
],
233-
module_filter=re.compile(r'.+l'))
233+
module_filter=lambda name: re.compile(r'.+l').match(name))
234234
sample = cps.sample(999, sort=True)
235235
self.assertLen(sample, 3)
236236
self.assertEqual(sample[0].name, 'middle')

compiler_opt/tools/generate_default_trace.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ def main(_):
146146

147147
cps = corpus.Corpus(
148148
data_path=_DATA_PATH.value,
149-
module_filter=module_filter,
149+
module_filter=lambda name: True
150+
if not module_filter else module_filter.match(name),
150151
additional_flags=config.flags_to_add(),
151152
delete_flags=config.flags_to_delete(),
152153
replace_flags=config.flags_to_replace())

0 commit comments

Comments
 (0)