Skip to content

Commit 3643388

Browse files
[𝘀𝗽𝗿] initial version
Created using spr 1.3.4
2 parents 3ddfad9 + 9749443 commit 3643388

File tree

3 files changed

+112
-14
lines changed

3 files changed

+112
-14
lines changed

compiler_opt/es/blackbox_evaluator.py

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,18 @@ def get_results(
4545
def set_baseline(self) -> None:
4646
raise NotImplementedError()
4747

48-
@abc.abstractmethod
4948
def get_rewards(
5049
self, results: List[concurrent.futures.Future]) -> List[Optional[float]]:
51-
raise NotImplementedError()
50+
rewards = [None] * len(results)
51+
52+
for i in range(len(results)):
53+
if not results[i].exception():
54+
rewards[i] = results[i].result()
55+
else:
56+
logging.info('Error retrieving result from future: %s',
57+
str(results[i].exception()))
58+
59+
return rewards
5260

5361

5462
@gin.configurable
@@ -96,15 +104,58 @@ def get_results(
96104
def set_baseline(self) -> None:
97105
pass
98106

99-
def get_rewards(
100-
self, results: List[concurrent.futures.Future]) -> List[Optional[float]]:
101-
rewards = [None] * len(results)
102107

103-
for i in range(len(results)):
104-
if not results[i].exception():
105-
rewards[i] = results[i].result()
106-
else:
107-
logging.info('Error retrieving result from future: %s',
108-
str(results[i].exception()))
108+
@gin.configurable
109+
class TraceBlackboxEvaluator(BlackboxEvaluator):
110+
"""A blackbox evaluator that utilizes trace based cost modelling."""
109111

110-
return rewards
112+
def __init__(self, train_corpus: corpus.Corpus,
113+
est_type: blackbox_optimizers.EstimatorType, bb_trace_path: str,
114+
function_index_path: str):
115+
self._train_corpus = train_corpus
116+
self._est_type = est_type
117+
self._bb_trace_path = bb_trace_path
118+
self._function_index_path = function_index_path
119+
120+
self._has_baseline = False
121+
self._baseline = 1
122+
123+
def get_results(
124+
self, pool: FixedWorkerPool, perturbations: List[policy_saver.Policy]
125+
) -> List[concurrent.futures.Future]:
126+
job_args = []
127+
for perturbation in perturbations:
128+
job_args.append(
129+
(self._train_corpus.module_specs, self._function_index_path,
130+
self._bb_trace_path, perturbation))
131+
132+
_, futures = buffered_scheduler.schedule_on_worker_pool(
133+
action=lambda w, v: w.compile_corpus_and_evaluate(
134+
v[0], v[1], v[2], v[3]),
135+
jobs=job_args,
136+
worker_pool=pool)
137+
concurrent.futures.wait(
138+
futures, return_when=concurrent.futures.ALL_COMPLETED)
139+
return futures
140+
141+
def set_baseline(self, pool: FixedWorkerPool) -> None:
142+
if self._has_baseline:
143+
return
144+
145+
job_args = [(
146+
self._train_corpus.module_specs,
147+
self._function_index_path,
148+
self._bb_trace_path,
149+
None,
150+
)]
151+
152+
_, futures = buffered_scheduler.schedule_on_worker_pool(
153+
action=lambda w, v: w.compile_corpus_and_evaluate(
154+
v[0], v[1], v[2], v[3]),
155+
jobs=job_args,
156+
worker_pool=pool)
157+
158+
concurrent.futures.wait(
159+
futures, return_when=concurrent.futures.ALL_COMPLETED)
160+
self._baseline = futures[0].result()
161+
self._has_baseline = True

compiler_opt/es/blackbox_evaluator_test.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_sampling_get_results(self):
4141
self.assertSequenceAlmostEqual([result.result() for result in results],
4242
[1.0, 1.0, 1.0])
4343

44-
def test_sampling_get_rewards(self):
44+
def test_get_rewards(self):
4545
f1 = concurrent.futures.Future()
4646
f1.set_exception(None)
4747
f2 = concurrent.futures.Future()
@@ -50,3 +50,27 @@ def test_sampling_get_rewards(self):
5050
evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(None, 5, 5, None)
5151
rewards = evaluator.get_rewards(results)
5252
self.assertEqual(rewards, [None, 2])
53+
54+
def test_trace_get_results(self):
55+
with local_worker_manager.LocalWorkerPoolManager(
56+
blackbox_test_utils.ESTraceWorker, count=3, arg='', kwarg='') as pool:
57+
perturbations = [b'00', b'01', b'10']
58+
test_corpus = corpus.create_corpus_for_testing(
59+
location=self.create_tempdir(),
60+
elements=[corpus.ModuleSpec(name='name1', size=1)])
61+
evaluator = blackbox_evaluator.TraceBlackboxEvaluator(
62+
test_corpus, 5, 'fake_bb_trace_path', 'fake_function_index_path')
63+
results = evaluator.get_results(pool, perturbations)
64+
self.assertSequenceAlmostEqual([result.result() for result in results],
65+
[1.0, 1.0, 1.0])
66+
67+
def test_trace_set_baseline(self):
68+
with local_worker_manager.LocalWorkerPoolManager(
69+
blackbox_test_utils.ESTraceWorker, count=1, arg='', kwarg='') as pool:
70+
test_corpus = corpus.create_corpus_for_testing(
71+
location=self.create_tempdir(),
72+
elements=[corpus.ModuleSpec(name='name1', size=1)])
73+
evaluator = blackbox_evaluator.TraceBlackboxEvaluator(
74+
test_corpus, 5, 'fake_bb_trace_path', 'fake_function_index_path')
75+
evaluator.set_baseline(pool)
76+
self.assertAlmostEqual(evaluator._baseline, 10)

compiler_opt/es/blackbox_test_utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
"""Test facilities for Blackbox classes."""
1616

17-
from typing import List
17+
from typing import List, Collection, Optional
1818

1919
import gin
2020

@@ -41,3 +41,26 @@ def compile(self, policy: policy_saver.Policy,
4141
return self.function_value
4242
else:
4343
return 0.0
44+
45+
46+
class ESTraceWorker(worker.Worker):
47+
"""Temporary placeholder worker.
48+
49+
This is a test worker for TraceBlackboxEvaluator that expects a slightly
50+
different interface than other workers.
51+
"""
52+
53+
def __init__(self, arg, *, kwarg):
54+
del arg # Unused.
55+
del kwarg # Unused.
56+
self._function_value = 0.0
57+
58+
def compile_corpus_and_evaluate(
59+
self, modules: Collection[corpus.ModuleSpec], function_index_path: str,
60+
bb_trace_path: str,
61+
tflite_policy: Optional[policy_saver.Policy]) -> float:
62+
if modules and function_index_path and bb_trace_path and tflite_policy:
63+
self._function_value += 1
64+
return self._function_value
65+
else:
66+
return 10

0 commit comments

Comments
 (0)