@@ -42,7 +42,7 @@ def get_results(
42
42
raise NotImplementedError ()
43
43
44
44
@abc .abstractmethod
45
- def set_baseline (self ) -> None :
45
+ def set_baseline (self , pool : FixedWorkerPool ) -> None :
46
46
raise NotImplementedError ()
47
47
48
48
def get_rewards (
@@ -101,5 +101,65 @@ def get_results(
101
101
102
102
return futures
103
103
104
- def set_baseline (self ) -> None :
104
+ def set_baseline (self , pool : FixedWorkerPool ) -> None :
105
+ del pool # Unused.
105
106
pass
107
+
108
+
109
+ @gin .configurable
110
+ class TraceBlackboxEvaluator (BlackboxEvaluator ):
111
+ """A blackbox evaluator that utilizes trace based cost modelling."""
112
+
113
+ def __init__ (self , train_corpus : corpus .Corpus ,
114
+ est_type : blackbox_optimizers .EstimatorType , bb_trace_path : str ,
115
+ function_index_path : str ):
116
+ self ._train_corpus = train_corpus
117
+ self ._est_type = est_type
118
+ self ._bb_trace_path = bb_trace_path
119
+ self ._function_index_path = function_index_path
120
+
121
+ self ._baseline : Optional [float ] = None
122
+
123
+ def get_results (
124
+ self , pool : FixedWorkerPool , perturbations : List [policy_saver .Policy ]
125
+ ) -> List [concurrent .futures .Future ]:
126
+ job_args = []
127
+ for perturbation in perturbations :
128
+ job_args .append ({
129
+ 'modules' : self ._train_corpus .module_specs ,
130
+ 'function_index_path' : self ._function_index_path ,
131
+ 'bb_trace_path' : self ._bb_trace_path ,
132
+ 'tflite_policy' : perturbation
133
+ })
134
+
135
+ _ , futures = buffered_scheduler .schedule_on_worker_pool (
136
+ action = lambda w , args : w .compile_corpus_and_evaluate (** args ),
137
+ jobs = job_args ,
138
+ worker_pool = pool )
139
+ concurrent .futures .wait (
140
+ futures , return_when = concurrent .futures .ALL_COMPLETED )
141
+ return futures
142
+
143
+ def set_baseline (self , pool : FixedWorkerPool ) -> None :
144
+ if self ._baseline is not None :
145
+ raise RuntimeError ('The baseline has already been set.' )
146
+
147
+ job_args = [{
148
+ 'modules' : self ._train_corpus .module_specs ,
149
+ 'function_index_path' : self ._function_index_path ,
150
+ 'bb_trace_path' : self ._bb_trace_path ,
151
+ 'tflite_policy' : None ,
152
+ }]
153
+
154
+ _ , futures = buffered_scheduler .schedule_on_worker_pool (
155
+ action = lambda w , args : w .compile_corpus_and_evaluate (** args ),
156
+ jobs = job_args ,
157
+ worker_pool = pool )
158
+
159
+ concurrent .futures .wait (
160
+ futures , return_when = concurrent .futures .ALL_COMPLETED )
161
+ if len (futures ) != 1 :
162
+ raise ValueError ('Expected to have one result for setting the baseline,'
163
+ f' got { len (futures )} ' )
164
+
165
+ self ._baseline = futures [0 ].result ()
0 commit comments