21
21
from compiler_opt .rl import corpus
22
22
from compiler_opt .es import blackbox_test_utils
23
23
from compiler_opt .es import blackbox_evaluator
24
+ from compiler_opt .es import blackbox_optimizers
24
25
25
26
26
27
class BlackboxEvaluatorTests (absltest .TestCase ):
@@ -33,7 +34,8 @@ def test_sampling_get_results(self):
33
34
worker_args = ('' ,),
34
35
worker_kwargs = dict (kwarg = '' )) as pool :
35
36
perturbations = [b'00' , b'01' , b'10' ]
36
- evaluator = blackbox_evaluator .SamplingBlackboxEvaluator (None , 5 , 5 , None )
37
+ evaluator = blackbox_evaluator .SamplingBlackboxEvaluator (
38
+ None , blackbox_optimizers .EstimatorType .FORWARD_FD , 5 , None )
37
39
# pylint: disable=protected-access
38
40
evaluator ._samples = [[corpus .ModuleSpec (name = 'name1' , size = 1 )],
39
41
[corpus .ModuleSpec (name = 'name2' , size = 1 )],
@@ -49,7 +51,8 @@ def test_get_rewards(self):
49
51
f2 = concurrent .futures .Future ()
50
52
f2 .set_result (2 )
51
53
results = [f1 , f2 ]
52
- evaluator = blackbox_evaluator .SamplingBlackboxEvaluator (None , 5 , 5 , None )
54
+ evaluator = blackbox_evaluator .SamplingBlackboxEvaluator (
55
+ None , blackbox_optimizers .EstimatorType .FORWARD_FD , 5 , None )
53
56
rewards = evaluator .get_rewards (results )
54
57
self .assertEqual (rewards , [None , 2 ])
55
58
@@ -64,7 +67,8 @@ def test_trace_get_results(self):
64
67
location = self .create_tempdir ().full_path ,
65
68
elements = [corpus .ModuleSpec (name = 'name1' , size = 1 )])
66
69
evaluator = blackbox_evaluator .TraceBlackboxEvaluator (
67
- test_corpus , 5 , 'fake_bb_trace_path' , 'fake_function_index_path' )
70
+ test_corpus , blackbox_optimizers .EstimatorType .FORWARD_FD ,
71
+ 'fake_bb_trace_path' , 'fake_function_index_path' )
68
72
results = evaluator .get_results (pool , perturbations )
69
73
self .assertSequenceAlmostEqual ([result .result () for result in results ],
70
74
[1.0 , 1.0 , 1.0 ])
@@ -79,7 +83,8 @@ def test_trace_set_baseline(self):
79
83
location = self .create_tempdir ().full_path ,
80
84
elements = [corpus .ModuleSpec (name = 'name1' , size = 1 )])
81
85
evaluator = blackbox_evaluator .TraceBlackboxEvaluator (
82
- test_corpus , 5 , 'fake_bb_trace_path' , 'fake_function_index_path' )
86
+ test_corpus , blackbox_optimizers .EstimatorType .FORWARD_FD ,
87
+ 'fake_bb_trace_path' , 'fake_function_index_path' )
83
88
evaluator .set_baseline (pool )
84
89
# pylint: disable=protected-access
85
90
self .assertAlmostEqual (evaluator ._baseline , 10 )
0 commit comments