Skip to content

Commit 3a4a297

Browse files
Add TraceBlackboxEvaluator
This patch adds TraceBlackboxEvaluator, an evaluator designed for trace based cost modelling. It implements the BlackboxEvaluator class, special casing everything that is needed. Reviewers: mtrofin Reviewed By: mtrofin Pull Request: #419
1 parent 96614ea commit 3a4a297

File tree

3 files changed

+111
-3
lines changed

3 files changed

+111
-3
lines changed

compiler_opt/es/blackbox_evaluator.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_results(
4242
raise NotImplementedError()
4343

4444
@abc.abstractmethod
45-
def set_baseline(self) -> None:
45+
def set_baseline(self, pool: FixedWorkerPool) -> None:
4646
raise NotImplementedError()
4747

4848
def get_rewards(
@@ -101,5 +101,65 @@ def get_results(
101101

102102
return futures
103103

104-
def set_baseline(self) -> None:
104+
def set_baseline(self, pool: FixedWorkerPool) -> None:
105+
del pool # Unused.
105106
pass
107+
108+
109+
@gin.configurable
110+
class TraceBlackboxEvaluator(BlackboxEvaluator):
111+
"""A blackbox evaluator that utilizes trace based cost modelling."""
112+
113+
def __init__(self, train_corpus: corpus.Corpus,
114+
est_type: blackbox_optimizers.EstimatorType, bb_trace_path: str,
115+
function_index_path: str):
116+
self._train_corpus = train_corpus
117+
self._est_type = est_type
118+
self._bb_trace_path = bb_trace_path
119+
self._function_index_path = function_index_path
120+
121+
self._baseline: Optional[float] = None
122+
123+
def get_results(
124+
self, pool: FixedWorkerPool, perturbations: List[policy_saver.Policy]
125+
) -> List[concurrent.futures.Future]:
126+
job_args = []
127+
for perturbation in perturbations:
128+
job_args.append({
129+
'modules': self._train_corpus.module_specs,
130+
'function_index_path': self._function_index_path,
131+
'bb_trace_path': self._bb_trace_path,
132+
'tflite_policy': perturbation
133+
})
134+
135+
_, futures = buffered_scheduler.schedule_on_worker_pool(
136+
action=lambda w, args: w.compile_corpus_and_evaluate(**args),
137+
jobs=job_args,
138+
worker_pool=pool)
139+
concurrent.futures.wait(
140+
futures, return_when=concurrent.futures.ALL_COMPLETED)
141+
return futures
142+
143+
def set_baseline(self, pool: FixedWorkerPool) -> None:
144+
if self._baseline is not None:
145+
raise RuntimeError('The baseline has already been set.')
146+
147+
job_args = [{
148+
'modules': self._train_corpus.module_specs,
149+
'function_index_path': self._function_index_path,
150+
'bb_trace_path': self._bb_trace_path,
151+
'tflite_policy': None,
152+
}]
153+
154+
_, futures = buffered_scheduler.schedule_on_worker_pool(
155+
action=lambda w, args: w.compile_corpus_and_evaluate(**args),
156+
jobs=job_args,
157+
worker_pool=pool)
158+
159+
concurrent.futures.wait(
160+
futures, return_when=concurrent.futures.ALL_COMPLETED)
161+
if len(futures) != 1:
162+
raise ValueError('Expected to have one result for setting the baseline,'
163+
f' got {len(futures)}')
164+
165+
self._baseline = futures[0].result()

compiler_opt/es/blackbox_evaluator_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,28 @@ def test_get_rewards(self):
5050
evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(None, 5, 5, None)
5151
rewards = evaluator.get_rewards(results)
5252
self.assertEqual(rewards, [None, 2])
53+
54+
def test_trace_get_results(self):
55+
with local_worker_manager.LocalWorkerPoolManager(
56+
blackbox_test_utils.ESTraceWorker, count=3, arg='', kwarg='') as pool:
57+
perturbations = [b'00', b'01', b'10']
58+
test_corpus = corpus.create_corpus_for_testing(
59+
location=self.create_tempdir(),
60+
elements=[corpus.ModuleSpec(name='name1', size=1)])
61+
evaluator = blackbox_evaluator.TraceBlackboxEvaluator(
62+
test_corpus, 5, 'fake_bb_trace_path', 'fake_function_index_path')
63+
results = evaluator.get_results(pool, perturbations)
64+
self.assertSequenceAlmostEqual([result.result() for result in results],
65+
[1.0, 1.0, 1.0])
66+
67+
def test_trace_set_baseline(self):
68+
with local_worker_manager.LocalWorkerPoolManager(
69+
blackbox_test_utils.ESTraceWorker, count=1, arg='', kwarg='') as pool:
70+
test_corpus = corpus.create_corpus_for_testing(
71+
location=self.create_tempdir(),
72+
elements=[corpus.ModuleSpec(name='name1', size=1)])
73+
evaluator = blackbox_evaluator.TraceBlackboxEvaluator(
74+
test_corpus, 5, 'fake_bb_trace_path', 'fake_function_index_path')
75+
evaluator.set_baseline(pool)
76+
# pylint: disable=protected-access
77+
self.assertAlmostEqual(evaluator._baseline, 10)

compiler_opt/es/blackbox_test_utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
"""Test facilities for Blackbox classes."""
1616

17-
from typing import List
17+
from typing import List, Collection, Optional
1818

1919
import gin
2020

@@ -41,3 +41,26 @@ def compile(self, policy: policy_saver.Policy,
4141
return self.function_value
4242
else:
4343
return 0.0
44+
45+
46+
class ESTraceWorker(worker.Worker):
47+
"""Temporary placeholder worker.
48+
49+
This is a test worker for TraceBlackboxEvaluator that expects a slightly
50+
different interface than other workers.
51+
"""
52+
53+
def __init__(self, arg, *, kwarg):
54+
del arg # Unused.
55+
del kwarg # Unused.
56+
self._function_value = 0.0
57+
58+
def compile_corpus_and_evaluate(
59+
self, modules: Collection[corpus.ModuleSpec], function_index_path: str,
60+
bb_trace_path: str,
61+
tflite_policy: Optional[policy_saver.Policy]) -> float:
62+
if modules and function_index_path and bb_trace_path and tflite_policy:
63+
self._function_value += 1
64+
return self._function_value
65+
else:
66+
return 10

0 commit comments

Comments
 (0)