Skip to content

Commit 711e230

Browse files
Pass EstimatorType in blackbox_evaluator_test
This patch passes an actual EstimatorType value when instantiating BlackboxEvaluator classes. This makes the test quite a bit easier to read as the estimator type names actually mean something rather than just the integers that were stuck there before given the values do not really matter.
1 parent 8a3c1fa commit 711e230

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

compiler_opt/es/blackbox_evaluator_test.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from compiler_opt.rl import corpus
2222
from compiler_opt.es import blackbox_test_utils
2323
from compiler_opt.es import blackbox_evaluator
24+
from compiler_opt.es import blackbox_optimizers
2425

2526

2627
class BlackboxEvaluatorTests(absltest.TestCase):
@@ -33,7 +34,8 @@ def test_sampling_get_results(self):
3334
worker_args=('',),
3435
worker_kwargs=dict(kwarg='')) as pool:
3536
perturbations = [b'00', b'01', b'10']
36-
evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(None, 5, 5, None)
37+
evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(
38+
None, blackbox_optimizers.EstimatorType.FORWARD_FD, 5, None)
3739
# pylint: disable=protected-access
3840
evaluator._samples = [[corpus.ModuleSpec(name='name1', size=1)],
3941
[corpus.ModuleSpec(name='name2', size=1)],
@@ -49,7 +51,8 @@ def test_get_rewards(self):
4951
f2 = concurrent.futures.Future()
5052
f2.set_result(2)
5153
results = [f1, f2]
52-
evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(None, 5, 5, None)
54+
evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(
55+
None, blackbox_optimizers.EstimatorType.FORWARD_FD, 5, None)
5356
rewards = evaluator.get_rewards(results)
5457
self.assertEqual(rewards, [None, 2])
5558

@@ -64,7 +67,8 @@ def test_trace_get_results(self):
6467
location=self.create_tempdir().full_path,
6568
elements=[corpus.ModuleSpec(name='name1', size=1)])
6669
evaluator = blackbox_evaluator.TraceBlackboxEvaluator(
67-
test_corpus, 5, 'fake_bb_trace_path', 'fake_function_index_path')
70+
test_corpus, blackbox_optimizers.EstimatorType.FORWARD_FD,
71+
'fake_bb_trace_path', 'fake_function_index_path')
6872
results = evaluator.get_results(pool, perturbations)
6973
self.assertSequenceAlmostEqual([result.result() for result in results],
7074
[1.0, 1.0, 1.0])
@@ -79,7 +83,8 @@ def test_trace_set_baseline(self):
7983
location=self.create_tempdir().full_path,
8084
elements=[corpus.ModuleSpec(name='name1', size=1)])
8185
evaluator = blackbox_evaluator.TraceBlackboxEvaluator(
82-
test_corpus, 5, 'fake_bb_trace_path', 'fake_function_index_path')
86+
test_corpus, blackbox_optimizers.EstimatorType.FORWARD_FD,
87+
'fake_bb_trace_path', 'fake_function_index_path')
8388
evaluator.set_baseline(pool)
8489
# pylint: disable=protected-access
8590
self.assertAlmostEqual(evaluator._baseline, 10)

0 commit comments

Comments
 (0)