Skip to content

Commit a7e3fdf

Browse files
Move TFLite Policy Generation to ES Workers
This patch moves TFLite policy generation for ES to the workers. This step is reasonably expensive, and doing it serially on the main node takes quite some time once there is a reasonable number of perturbations. Moving it to the workers parallelizes it. Reviewers: mtrofin. Reviewed By: mtrofin. Pull Request: #448
1 parent c450ded commit a7e3fdf

File tree

7 files changed

+88
-92
lines changed

7 files changed

+88
-92
lines changed

compiler_opt/es/blackbox_evaluator.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from compiler_opt.rl import corpus
2424
from compiler_opt.es import blackbox_optimizers
2525
from compiler_opt.distributed import buffered_scheduler
26-
from compiler_opt.rl import policy_saver
2726

2827

2928
class BlackboxEvaluator(metaclass=abc.ABCMeta):
@@ -35,8 +34,8 @@ def __init__(self, train_corpus: corpus.Corpus):
3534

3635
@abc.abstractmethod
3736
def get_results(
38-
self, pool: FixedWorkerPool, perturbations: list[policy_saver.Policy]
39-
) -> list[concurrent.futures.Future]:
37+
self, pool: FixedWorkerPool,
38+
perturbations: list[bytes]) -> list[concurrent.futures.Future]:
4039
raise NotImplementedError()
4140

4241
@abc.abstractmethod
@@ -73,8 +72,8 @@ def __init__(self, train_corpus: corpus.Corpus,
7372
super().__init__(train_corpus)
7473

7574
def get_results(
76-
self, pool: FixedWorkerPool, perturbations: list[policy_saver.Policy]
77-
) -> list[concurrent.futures.Future]:
75+
self, pool: FixedWorkerPool,
76+
perturbations: list[bytes]) -> list[concurrent.futures.Future]:
7877
if not self._samples:
7978
for _ in range(self._total_num_perturbations):
8079
sample = self._train_corpus.sample(self._num_ir_repeats_within_worker)
@@ -120,13 +119,13 @@ def __init__(self, train_corpus: corpus.Corpus,
120119
self._baseline: float | None = None
121120

122121
def get_results(
123-
self, pool: FixedWorkerPool, perturbations: list[policy_saver.Policy]
124-
) -> list[concurrent.futures.Future]:
122+
self, pool: FixedWorkerPool,
123+
perturbations: list[bytes]) -> list[concurrent.futures.Future]:
125124
job_args = [{
126125
'modules': self._train_corpus.module_specs,
127126
'function_index_path': self._function_index_path,
128127
'bb_trace_path': self._bb_trace_path,
129-
'tflite_policy': perturbation
128+
'policy_as_bytes': perturbation,
130129
} for perturbation in perturbations]
131130

132131
_, futures = buffered_scheduler.schedule_on_worker_pool(
@@ -145,7 +144,7 @@ def set_baseline(self, pool: FixedWorkerPool) -> None:
145144
'modules': self._train_corpus.module_specs,
146145
'function_index_path': self._function_index_path,
147146
'bb_trace_path': self._bb_trace_path,
148-
'tflite_policy': None,
147+
'policy_as_bytes': None,
149148
}]
150149

151150
_, futures = buffered_scheduler.schedule_on_worker_pool(

compiler_opt/es/blackbox_learner.py

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,18 @@
1313
# limitations under the License.
1414
"""Class for coordinating blackbox optimization."""
1515

16-
import os
1716
from absl import logging
1817
import dataclasses
1918
import gin
2019
import math
2120
import numpy as np
2221
import numpy.typing as npt
23-
import tempfile
2422
import tensorflow as tf
2523
from typing import Protocol
2624

2725
from compiler_opt.distributed.worker import FixedWorkerPool
2826
from compiler_opt.es import blackbox_optimizers
29-
from compiler_opt.es import policy_utils
3027
from compiler_opt.rl import corpus
31-
from compiler_opt.rl import policy_saver
3228
from compiler_opt.es import blackbox_evaluator # pylint: disable=unused-import
3329

3430
# Pytype cannot pick up the pyi file for tensorflow.summary. Disable the error
@@ -128,7 +124,6 @@ class BlackboxLearner:
128124
def __init__(self,
129125
blackbox_opt: blackbox_optimizers.BlackboxOptimizer,
130126
train_corpus: corpus.Corpus,
131-
tf_policy_path: str,
132127
output_dir: str,
133128
policy_saver_fn: PolicySaverCallableType,
134129
model_weights: npt.NDArray[np.float32],
@@ -141,7 +136,6 @@ def __init__(self,
141136
Args:
142137
blackbox_opt: the blackbox optimizer to use
143138
train_corpus: the training corpus to utilize
144-
tf_policy_path: where to write the tf policy
145139
output_dir: the directory to write all outputs
146140
policy_saver_fn: function to save a policy to cns
147141
model_weights: the weights of the current model
@@ -151,7 +145,6 @@ def __init__(self,
151145
"""
152146
self._blackbox_opt = blackbox_opt
153147
self._train_corpus = train_corpus
154-
self._tf_policy_path = tf_policy_path
155148
self._output_dir = output_dir
156149
self._policy_saver_fn = policy_saver_fn
157150
self._model_weights = model_weights
@@ -237,29 +230,6 @@ def _save_model(self) -> None:
237230
def get_model_weights(self) -> npt.NDArray[np.float32]:
238231
return self._model_weights
239232

240-
# TODO: The current conversion is inefficient (performance-wise). We should
241-
# consider doing this on the worker side.
242-
def _get_policy_from_perturbation(
243-
self, perturbation: npt.NDArray[np.float32]) -> policy_saver.Policy:
244-
sm = tf.saved_model.load(self._tf_policy_path)
245-
# devectorize the perturbation
246-
policy_utils.set_vectorized_parameters_for_policy(sm, perturbation)
247-
248-
with tempfile.TemporaryDirectory() as tmpdir:
249-
sm_dir = os.path.join(tmpdir, 'sm')
250-
tf.saved_model.save(sm, sm_dir, signatures=sm.signatures)
251-
src = os.path.join(self._tf_policy_path, policy_saver.OUTPUT_SIGNATURE)
252-
dst = os.path.join(sm_dir, policy_saver.OUTPUT_SIGNATURE)
253-
tf.io.gfile.copy(src, dst)
254-
255-
# convert to tflite
256-
tfl_dir = os.path.join(tmpdir, 'tfl')
257-
policy_saver.convert_mlgo_model(sm_dir, tfl_dir)
258-
259-
# create and return policy
260-
policy_obj = policy_saver.Policy.from_filesystem(tfl_dir)
261-
return policy_obj
262-
263233
def run_step(self, pool: FixedWorkerPool) -> None:
264234
"""Run a single step of blackbox learning.
265235
This does not instantaneously return due to several I/O
@@ -275,12 +245,16 @@ def run_step(self, pool: FixedWorkerPool) -> None:
275245
p for p in initial_perturbations for p in (p, -p)
276246
]
277247

278-
perturbations_as_policies = [
279-
self._get_policy_from_perturbation(perturbation)
248+
# TODO(boomanaiden154): This should be adding the perturbation to
249+
# the existing model weights. That currently results in the model
250+
# weights all being NaN, presumably due to rewards not being scaled for
251+
# the regalloc_trace problem.
252+
perturbations_as_bytes = [
253+
perturbation.astype(np.float32).tobytes()
280254
for perturbation in initial_perturbations
281255
]
282256

283-
results = self._evaluator.get_results(pool, perturbations_as_policies)
257+
results = self._evaluator.get_results(pool, perturbations_as_bytes)
284258
rewards = self._evaluator.get_rewards(results)
285259

286260
num_pruned = _prune_skipped_perturbations(initial_perturbations, rewards)

compiler_opt/es/blackbox_learner_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
from tf_agents.networks import actor_distribution_network
2525
from tf_agents.policies import actor_policy
2626

27+
# Pytype cannot pick up the pyi file for tensorflow.summary. Disable the error
28+
# here as these errors are false positives.
29+
# pytype: disable=pyi-error
30+
2731
from compiler_opt.distributed.local import local_worker_manager
2832
from compiler_opt.es import blackbox_learner
2933
from compiler_opt.es import policy_utils
@@ -124,7 +128,6 @@ def _policy_saver_fn(parameters: npt.NDArray[np.float32],
124128
extra_params=None,
125129
step_size=1),
126130
train_corpus=self._cps,
127-
tf_policy_path=os.path.join(policy_save_path, policy_name),
128131
output_dir=output_dir,
129132
policy_saver_fn=_policy_saver_fn,
130133
model_weights=init_params,

compiler_opt/es/blackbox_test_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ def __init__(self, arg, *, kwarg):
5454
del kwarg # Unused.
5555
self._function_value = 0.0
5656

57-
def compile_corpus_and_evaluate(
58-
self, modules: Collection[corpus.ModuleSpec], function_index_path: str,
59-
bb_trace_path: str, tflite_policy: policy_saver.Policy | None) -> float:
60-
if modules and function_index_path and bb_trace_path and tflite_policy:
57+
def compile_corpus_and_evaluate(self, modules: Collection[corpus.ModuleSpec],
58+
function_index_path: str, bb_trace_path: str,
59+
policy_as_bytes: bytes | None) -> float:
60+
if modules and function_index_path and bb_trace_path and policy_as_bytes:
6161
self._function_value += 1
6262
return self._function_value
6363
else:

compiler_opt/es/es_trainer_lib.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,17 @@
2020
import tensorflow as tf
2121
import os
2222

23+
# Pytype cannot pick up the pyi file for tensorflow.summary. Disable the error
24+
# here as these errors are false positives.
25+
# pytype: disable=pyi-error
26+
2327
from compiler_opt.distributed.local import local_worker_manager
2428
from compiler_opt.es import blackbox_optimizers
2529
from compiler_opt.es import gradient_ascent_optimization_algorithms
2630
from compiler_opt.es import blackbox_learner
2731
from compiler_opt.es import policy_utils
28-
from compiler_opt.rl import policy_saver
2932
from compiler_opt.rl import corpus
3033

31-
POLICY_NAME = "policy"
32-
3334
FLAGS = flags.FLAGS
3435

3536
_GRAD_REG_ALPHA = flags.DEFINE_float(
@@ -79,11 +80,6 @@ def train(additional_compilation_flags=(),
7980

8081
# Construct the policy and upload it
8182
policy = policy_utils.create_actor_policy()
82-
saver = policy_saver.PolicySaver({POLICY_NAME: policy})
83-
84-
# Save the policy
85-
policy_save_path = os.path.join(_OUTPUT_PATH.value, "policy")
86-
saver.save(policy_save_path)
8783

8884
# Get initial parameter
8985
if not _PRETRAINED_POLICY_PATH.value:
@@ -201,7 +197,6 @@ def train(additional_compilation_flags=(),
201197
learner = blackbox_learner.BlackboxLearner(
202198
blackbox_opt=blackbox_optimizer,
203199
train_corpus=cps,
204-
tf_policy_path=os.path.join(policy_save_path, POLICY_NAME),
205200
output_dir=_OUTPUT_PATH.value,
206201
policy_saver_fn=policy_saver_function,
207202
model_weights=init_current_input,
@@ -216,8 +211,10 @@ def train(additional_compilation_flags=(),
216211
logging.info("Ready to train: running for %d steps.",
217212
learner_config.total_steps)
218213

219-
with worker_manager_class(worker_class,
220-
learner_config.total_num_perturbations) as pool:
214+
with worker_manager_class(
215+
worker_class,
216+
learner_config.total_num_perturbations,
217+
worker_kwargs=dict(gin_config=gin.operative_config_str())) as pool:
221218
for _ in range(learner_config.total_steps):
222219
learner.run_step(pool)
223220

compiler_opt/es/regalloc_trace/regalloc_trace_worker.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@
2626
import json
2727
import concurrent.futures
2828
import tempfile
29+
import shutil
2930

3031
import gin
3132

3233
from compiler_opt.rl import corpus
3334
from compiler_opt.distributed import worker
3435
from compiler_opt.rl import policy_saver
36+
from compiler_opt.es import policy_utils
3537

3638

3739
@gin.configurable
@@ -44,8 +46,16 @@ class RegallocTraceWorker(worker.Worker):
4446
segments.
4547
"""
4648

47-
def __init__(self, clang_path: str, basic_block_trace_model_path: str,
48-
thread_count: int, corpus_path: str):
49+
def _setup_base_policy(self):
  """Saves a freshly created actor policy into a new temporary directory.

  Sets two attributes on the worker:
    _tf_base_temp_dir: the temporary directory created for the saved policy.
    _tf_base_policy_path: the "policy" entry inside that directory
      (presumably where PolicySaver writes the policy registered under the
      "policy" key -- confirm against policy_saver).
  """
  base_dir = tempfile.mkdtemp()
  # Record the directory before saving so that it is known to the worker
  # even if the save below raises.
  self._tf_base_temp_dir = base_dir
  policy_saver.PolicySaver({
      "policy": policy_utils.create_actor_policy()
  }).save(base_dir)
  self._tf_base_policy_path = os.path.join(base_dir, "policy")
55+
56+
def __init__(self, *, gin_config: str, clang_path: str,
57+
basic_block_trace_model_path: str, thread_count: int,
58+
corpus_path: str):
4959
"""Initializes the RegallocTraceWorker class.
5060
5161
Args:
@@ -64,6 +74,16 @@ def __init__(self, clang_path: str, basic_block_trace_model_path: str,
6474
self._thread_count = thread_count
6575
self._corpus_path = corpus_path
6676

77+
gin.parse_config(gin_config)
78+
self._setup_base_policy()
79+
80+
# Deletion here is best effort as it occurs at GC time. If the shutdown is
81+
# forced, cleanup might not happen as expected. This does not matter too
82+
# much though as resource leakage will be small, and any cloud setups will
83+
# have tempdirs wiped periodically.
84+
def __del__(self):
  """Removes the temporary base-policy directory on a best-effort basis.

  Finalization also runs for objects whose __init__ raised before
  _tf_base_temp_dir was assigned, so the attribute is looked up
  defensively instead of being assumed to exist.
  """
  base_temp_dir = getattr(self, "_tf_base_temp_dir", None)
  if base_temp_dir is not None:
    # Cleanup is best effort: ignore a directory that is already gone or
    # cannot be fully removed rather than raising during finalization.
    shutil.rmtree(base_temp_dir, ignore_errors=True)
86+
6787
def _compile_module(self, module_to_compile: corpus.ModuleSpec,
6888
output_directory: str, tflite_policy_path: str | None):
6989
command_vector = [self._clang_path]
@@ -97,20 +117,13 @@ def _compile_module(self, module_to_compile: corpus.ModuleSpec,
97117
subprocess.run(command_vector, check=True, capture_output=True)
98118

99119
def _build_corpus(self, modules: Collection[corpus.ModuleSpec],
100-
output_directory: str,
101-
tflite_policy: policy_saver.Policy | None):
102-
with tempfile.TemporaryDirectory() as tflite_policy_dir:
103-
if tflite_policy:
104-
tflite_policy.to_filesystem(tflite_policy_dir)
105-
else:
106-
tflite_policy_dir = None
107-
108-
with concurrent.futures.ThreadPoolExecutor(
109-
max_workers=self._thread_count) as thread_pool:
110-
compile_futures = [
111-
thread_pool.submit(self._compile_module, module, output_directory,
112-
tflite_policy_dir) for module in modules
113-
]
120+
output_directory: str, tflite_policy_path: str | None):
121+
with concurrent.futures.ThreadPoolExecutor(
122+
max_workers=self._thread_count) as thread_pool:
123+
compile_futures = [
124+
thread_pool.submit(self._compile_module, module, output_directory,
125+
tflite_policy_path) for module in modules
126+
]
114127

115128
for future in compile_futures:
116129
if future.exception() is not None:
@@ -158,11 +171,16 @@ def _evaluate_corpus(self, module_directory: str, function_index_path: str,
158171

159172
return segment_costs
160173

161-
def compile_corpus_and_evaluate(
162-
self, modules: Collection[corpus.ModuleSpec], function_index_path: str,
163-
bb_trace_path: str, tflite_policy: policy_saver.Policy | None) -> float:
174+
def compile_corpus_and_evaluate(self, modules: Collection[corpus.ModuleSpec],
175+
function_index_path: str, bb_trace_path: str,
176+
policy_as_bytes: bytes | None) -> float:
164177
with tempfile.TemporaryDirectory() as compilation_dir:
165-
self._build_corpus(modules, compilation_dir, tflite_policy)
178+
tflite_policy_path = None
179+
if policy_as_bytes is not None:
180+
tflite_policy_path = policy_utils.convert_to_tflite(
181+
policy_as_bytes, compilation_dir, self._tf_base_policy_path)
182+
183+
self._build_corpus(modules, compilation_dir, tflite_policy_path)
166184

167185
segment_costs = self._evaluate_corpus(compilation_dir,
168186
function_index_path, bb_trace_path)

0 commit comments

Comments
 (0)