@@ -45,10 +45,18 @@ def get_results(
4545 def set_baseline (self ) -> None :
4646 raise NotImplementedError ()
4747
48- @abc .abstractmethod
4948 def get_rewards (
5049 self , results : List [concurrent .futures .Future ]) -> List [Optional [float ]]:
51- raise NotImplementedError ()
50+ rewards = [None ] * len (results )
51+
52+ for i in range (len (results )):
53+ if not results [i ].exception ():
54+ rewards [i ] = results [i ].result ()
55+ else :
56+ logging .info ('Error retrieving result from future: %s' ,
57+ str (results [i ].exception ()))
58+
59+ return rewards
5260
5361
5462@gin .configurable
@@ -96,15 +104,58 @@ def get_results(
96104 def set_baseline (self ) -> None :
97105 pass
98106
99- def get_rewards (
100- self , results : List [concurrent .futures .Future ]) -> List [Optional [float ]]:
101- rewards = [None ] * len (results )
102107
103- for i in range (len (results )):
104- if not results [i ].exception ():
105- rewards [i ] = results [i ].result ()
106- else :
107- logging .info ('Error retrieving result from future: %s' ,
108- str (results [i ].exception ()))
108+ @gin .configurable
109+ class TraceBlackboxEvaluator (BlackboxEvaluator ):
110+ """A blackbox evaluator that utilizes trace based cost modelling."""
109111
110- return rewards
112+ def __init__ (self , train_corpus : corpus .Corpus ,
113+ est_type : blackbox_optimizers .EstimatorType , bb_trace_path : str ,
114+ function_index_path : str ):
115+ self ._train_corpus = train_corpus
116+ self ._est_type = est_type
117+ self ._bb_trace_path = bb_trace_path
118+ self ._function_index_path = function_index_path
119+
120+ self ._has_baseline = False
121+ self ._baseline = 1
122+
123+ def get_results (
124+ self , pool : FixedWorkerPool , perturbations : List [policy_saver .Policy ]
125+ ) -> List [concurrent .futures .Future ]:
126+ job_args = []
127+ for perturbation in perturbations :
128+ job_args .append (
129+ (self ._train_corpus .module_specs , self ._function_index_path ,
130+ self ._bb_trace_path , perturbation ))
131+
132+ _ , futures = buffered_scheduler .schedule_on_worker_pool (
133+ action = lambda w , v : w .compile_corpus_and_evaluate (
134+ v [0 ], v [1 ], v [2 ], v [3 ]),
135+ jobs = job_args ,
136+ worker_pool = pool )
137+ concurrent .futures .wait (
138+ futures , return_when = concurrent .futures .ALL_COMPLETED )
139+ return futures
140+
141+ def set_baseline (self , pool : FixedWorkerPool ) -> None :
142+ if self ._has_baseline :
143+ return
144+
145+ job_args = [(
146+ self ._train_corpus .module_specs ,
147+ self ._function_index_path ,
148+ self ._bb_trace_path ,
149+ None ,
150+ )]
151+
152+ _ , futures = buffered_scheduler .schedule_on_worker_pool (
153+ action = lambda w , v : w .compile_corpus_and_evaluate (
154+ v [0 ], v [1 ], v [2 ], v [3 ]),
155+ jobs = job_args ,
156+ worker_pool = pool )
157+
158+ concurrent .futures .wait (
159+ futures , return_when = concurrent .futures .ALL_COMPLETED )
160+ self ._baseline = futures [0 ].result ()
161+ self ._has_baseline = True
0 commit comments