Commit aae377a

Add blackbox_learner (#286)
1 parent 5f97020 commit aae377a

File tree

2 files changed: +527 -0 lines changed

compiler_opt/es/blackbox_learner.py

Lines changed: 337 additions & 0 deletions
@@ -0,0 +1,337 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for coordinating blackbox optimization."""

import os
from absl import logging
import concurrent.futures
import dataclasses
import gin
import math
import numpy as np
import numpy.typing as npt
import tempfile
import tensorflow as tf
from typing import List, Optional, Protocol

from compiler_opt.distributed import buffered_scheduler
from compiler_opt.distributed.worker import FixedWorkerPool
from compiler_opt.es import blackbox_optimizers
from compiler_opt.es import policy_utils
from compiler_opt.rl import corpus
from compiler_opt.rl import policy_saver

# If less than 40% of requests succeed, skip the step.
_SKIP_STEP_SUCCESS_RATIO = 0.4


@gin.configurable
@dataclasses.dataclass(frozen=True)
class BlackboxLearnerConfig:
  """Hyperparameter configuration for BlackboxLearner."""

  # Total steps to train for
  total_steps: int

  # Name of the blackbox optimization algorithm
  blackbox_optimizer: blackbox_optimizers.Algorithm

  # What kind of ES training?
  # - antithetic: for each perturbation, try an antiperturbation
  # - forward_fd: try total_num_perturbations independent perturbations
  est_type: blackbox_optimizers.EstimatorType

  # Should the rewards for blackbox optimization in a single step be normalized?
  fvalues_normalization: bool

  # How to update optimizer hyperparameters
  hyperparameters_update_method: blackbox_optimizers.UpdateMethod

  # Number of top performing perturbations to select in the optimizer
  # 0 means all
  num_top_directions: int

  # How many IR files to try a single perturbation on?
  num_ir_repeats_within_worker: int

  # How many times should we reuse IR to test different policies?
  num_ir_repeats_across_worker: int

  # How many IR files to sample from the test corpus at each iteration
  num_exact_evals: int

  # How many perturbations to attempt at each iteration
  total_num_perturbations: int

  # How much to scale the stdev of the perturbations
  precision_parameter: float

  # Learning rate
  step_size: float


def _prune_skipped_perturbations(perturbations: List[npt.NDArray[np.float32]],
                                 rewards: List[Optional[float]]):
  """Remove perturbations that were skipped during the training step.

  Perturbations may be skipped due to an early exit condition or a server error
  (clang timeout, malformed training example, etc). The blackbox optimizer
  assumes that each perturbation has a valid reward, so we must remove any of
  these "skipped" perturbations.

  Pruning occurs in-place.

  Args:
    perturbations: the model perturbations used for the ES training step.
    rewards: the rewards for each perturbation.

  Returns:
    The number of perturbations that were pruned.
  """
  indices_to_prune = []
  for i, reward in enumerate(rewards):
    if reward is None:
      indices_to_prune.append(i)

  # Iterate in reverse so that the indices remain valid
  for i in reversed(indices_to_prune):
    del perturbations[i]
    del rewards[i]

  return len(indices_to_prune)


class PolicySaverCallableType(Protocol):
  """Protocol for the policy saver function.

  A Protocol is required to type annotate the function with keyword arguments.
  """

  def __call__(self, parameters: npt.NDArray[np.float32],
               policy_name: str) -> None:
    ...


class BlackboxLearner:
  """Implementation of blackbox learning."""

  def __init__(self,
               blackbox_opt: blackbox_optimizers.BlackboxOptimizer,
               sampler: corpus.Corpus,
               tf_policy_path: str,
               output_dir: str,
               policy_saver_fn: PolicySaverCallableType,
               model_weights: npt.NDArray[np.float32],
               config: BlackboxLearnerConfig,
               initial_step: int = 0,
               deadline: float = 30.0,
               seed: Optional[int] = None):
    """Construct a BlackboxLearner.

    Args:
      blackbox_opt: the blackbox optimizer to use
      sampler: corpus sampler for training data.
      tf_policy_path: where to write the tf policy
      output_dir: the directory to write all outputs
      policy_saver_fn: function to save a policy to cns
      model_weights: the weights of the current model
      config: configuration for blackbox optimization.
      initial_step: the initial step for learning.
      deadline: the deadline in seconds for requests to the inlining server.
      seed: optional seed for the perturbation generator.
    """
    self._blackbox_opt = blackbox_opt
    self._sampler = sampler
    self._tf_policy_path = tf_policy_path
    self._output_dir = output_dir
    self._policy_saver_fn = policy_saver_fn
    self._model_weights = model_weights
    self._config = config
    self._step = initial_step
    self._deadline = deadline
    self._seed = seed

    # While we're waiting for the ES requests, we can
    # collect samples for the next round of training.
    self._samples = []

    self._summary_writer = tf.summary.create_file_writer(output_dir)

  def _get_perturbations(self) -> List[npt.NDArray[np.float32]]:
    """Get perturbations for the model weights."""
    perturbations = []
    rng = np.random.default_rng(seed=self._seed)
    for _ in range(self._config.total_num_perturbations):
      perturbations.append(
          rng.normal(size=(len(self._model_weights))) *
          self._config.precision_parameter)
    return perturbations

  def _get_rewards(
      self, results: List[concurrent.futures.Future]) -> List[Optional[float]]:
    """Convert ES results to reward numbers."""
    rewards = [None] * len(results)

    for i in range(len(results)):
      if not results[i].exception():
        rewards[i] = results[i].result()
      else:
        logging.info('Error retrieving result from future: %s',
                     str(results[i].exception()))

    return rewards

  def _update_model(self, perturbations: List[npt.NDArray[np.float32]],
                    rewards: List[float]) -> None:
    """Update the model given a list of perturbations and rewards."""
    self._model_weights = self._blackbox_opt.run_step(
        perturbations=np.array(perturbations),
        function_values=np.array(rewards),
        current_input=self._model_weights,
        current_value=np.mean(rewards))

  def _log_rewards(self, rewards: List[float]) -> None:
    """Log reward to console."""
    logging.info('Train reward: [%f]', np.mean(rewards))

  def _log_tf_summary(self, rewards: List[float]) -> None:
    """Log tensorboard data."""
    with self._summary_writer.as_default():
      tf.summary.scalar(
          'reward/average_reward_train', np.mean(rewards), step=self._step)

      tf.summary.histogram('reward/reward_train', rewards, step=self._step)

      train_regressions = [reward for reward in rewards if reward < 0]
      tf.summary.scalar(
          'reward/regression_probability_train',
          len(train_regressions) / len(rewards),
          step=self._step)

      tf.summary.scalar(
          'reward/regression_avg_train',
          np.mean(train_regressions) if len(train_regressions) > 0 else 0,
          step=self._step)

      # The "max regression" is the min value, as the regressions are negative.
      tf.summary.scalar(
          'reward/regression_max_train',
          min(train_regressions, default=0),
          step=self._step)

      train_wins = [reward for reward in rewards if reward > 0]
      tf.summary.scalar(
          'reward/win_probability_train',
          len(train_wins) / len(rewards),
          step=self._step)

  def _save_model(self) -> None:
    """Save the model."""
    logging.info('Saving the model.')
    self._policy_saver_fn(
        parameters=self._model_weights, policy_name=f'iteration{self._step}')

  def get_model_weights(self) -> npt.NDArray[np.float32]:
    return self._model_weights

  def _get_results(
      self, pool: FixedWorkerPool,
      perturbations: List[bytes]) -> List[concurrent.futures.Future]:
    if not self._samples:
      for _ in range(self._config.total_num_perturbations):
        sample = self._sampler.sample(self._config.num_ir_repeats_within_worker)
        self._samples.append(sample)
        # add copy of sample for antithetic perturbation pair
        if self._config.est_type == (
            blackbox_optimizers.EstimatorType.ANTITHETIC):
          self._samples.append(sample)

    compile_args = zip(perturbations, self._samples)

    _, futures = buffered_scheduler.schedule_on_worker_pool(
        action=lambda w, v: w.compile(v[0], v[1]),
        jobs=compile_args,
        worker_pool=pool)

    not_done = futures
    # wait for all futures to finish
    while not_done:
      # update lists as work gets done
      _, not_done = concurrent.futures.wait(
          not_done, return_when=concurrent.futures.FIRST_COMPLETED)

    return futures

  def _get_policy_as_bytes(self,
                           perturbation: npt.NDArray[np.float32]) -> bytes:
    sm = tf.saved_model.load(self._tf_policy_path)
    # devectorize the perturbation
    policy_utils.set_vectorized_parameters_for_policy(sm, perturbation)

    with tempfile.TemporaryDirectory() as tmpdir:
      sm_dir = os.path.join(tmpdir, 'sm')
      tf.saved_model.save(sm, sm_dir, signatures=sm.signatures)
      src = os.path.join(self._tf_policy_path, policy_saver.OUTPUT_SIGNATURE)
      dst = os.path.join(sm_dir, policy_saver.OUTPUT_SIGNATURE)
      tf.io.gfile.copy(src, dst)

      # convert to tflite
      tfl_dir = os.path.join(tmpdir, 'tfl')
      policy_saver.convert_mlgo_model(sm_dir, tfl_dir)

      # create and return policy
      policy_obj = policy_saver.Policy.from_filesystem(tfl_dir)
      return policy_obj.policy

  def run_step(self, pool: FixedWorkerPool) -> None:
    """Run a single step of blackbox learning.

    This does not return immediately; it blocks on I/O and the compile jobs
    while waiting for their responses."""
    logging.info('-' * 80)
    logging.info('Step [%d]', self._step)

    initial_perturbations = self._get_perturbations()
    # positive-negative pairs
    if self._config.est_type == blackbox_optimizers.EstimatorType.ANTITHETIC:
      initial_perturbations = [
          p for p in initial_perturbations for p in (p, -p)
      ]

    # convert to bytes for compile job
    # TODO: current conversion is inefficient.
    # consider doing this on the worker side
    perturbations_as_bytes = []
    for perturbation in initial_perturbations:
      perturbations_as_bytes.append(self._get_policy_as_bytes(perturbation))

    results = self._get_results(pool, perturbations_as_bytes)
    rewards = self._get_rewards(results)

    num_pruned = _prune_skipped_perturbations(initial_perturbations, rewards)
    logging.info('Pruned [%d]', num_pruned)
    min_num_rewards = math.ceil(_SKIP_STEP_SUCCESS_RATIO * len(results))
    if len(rewards) < min_num_rewards:
      logging.warning(
          'Skipping the step, too many requests failed: %d of %d '
          'train requests succeeded (required: %d)', len(rewards), len(results),
          min_num_rewards)
      return

    self._update_model(initial_perturbations, rewards)
    self._log_rewards(rewards)
    self._log_tf_summary(rewards)

    self._save_model()

    self._step += 1
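
Usage note: the sketch below shows how the new learner might be driven by a training binary. It is a minimal, hypothetical example and not part of this commit: the optimizer, corpus sampler, worker pool, and the save_policy helper are assumed to be built elsewhere, and the hyperparameter values and specific Algorithm/UpdateMethod members are illustrative only. Everything else uses only the interfaces introduced in this diff.

# Hypothetical driver sketch (not part of this commit). Assumes blackbox_opt,
# sampler (a corpus.Corpus), pool (a FixedWorkerPool) and the initial flattened
# policy weights are constructed by the surrounding training binary.
import numpy as np
import numpy.typing as npt

from compiler_opt.es import blackbox_learner
from compiler_opt.es import blackbox_optimizers


def save_policy(parameters: npt.NDArray[np.float32], policy_name: str) -> None:
  # Placeholder PolicySaverCallableType implementation: persist the flattened
  # weights however the deployment expects.
  np.save(f'/tmp/es_policy_{policy_name}.npy', parameters)


def train(blackbox_opt, sampler, pool,
          initial_weights: npt.NDArray[np.float32]):
  # Field values (and the Algorithm/UpdateMethod members) are illustrative;
  # in practice they would come from a gin config.
  config = blackbox_learner.BlackboxLearnerConfig(
      total_steps=100,
      blackbox_optimizer=blackbox_optimizers.Algorithm.MONTE_CARLO,
      est_type=blackbox_optimizers.EstimatorType.ANTITHETIC,
      fvalues_normalization=True,
      hyperparameters_update_method=blackbox_optimizers.UpdateMethod.NO_METHOD,
      num_top_directions=0,
      num_ir_repeats_within_worker=1,
      num_ir_repeats_across_worker=0,
      num_exact_evals=1,
      total_num_perturbations=16,
      precision_parameter=0.1,
      step_size=0.02)

  learner = blackbox_learner.BlackboxLearner(
      blackbox_opt=blackbox_opt,
      sampler=sampler,
      tf_policy_path='/tmp/saved_policy',
      output_dir='/tmp/es_output',
      policy_saver_fn=save_policy,
      model_weights=initial_weights,
      config=config,
      seed=42)

  for _ in range(config.total_steps):
    learner.run_step(pool)  # blocks until the compile jobs report rewards

  return learner.get_model_weights()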
