
Commit 185b385

committed
[unroll] Initial commit to add the loop-unroll problem into MLGO
1 parent 16eb08b commit 185b385

File tree

8 files changed, +354 -2 lines changed


compiler_opt/rl/compilation_runner.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
 import tensorflow as tf

 _COMPILATION_TIMEOUT = flags.DEFINE_integer(
-    'compilation_timeout', 60,
+    'compilation_timeout', 120,
     'Max duration (in seconds) after which we cancel any compilation job.')
 _QUIET = flags.DEFINE_bool(
     'quiet', True, 'Whether or not to compile quietly (hiding info logging)')
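Since compilation_timeout is an ordinary absl flag, the new 120-second default can still be overridden per run; a hedged example (the entry point is whichever trainer script already consumes these flags):

  --compilation_timeout=240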

compiler_opt/rl/registry.py

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@
 # to trigger gin registration.
 import compiler_opt.rl.inlining  # pylint: disable=unused-import
 import compiler_opt.rl.regalloc  # pylint: disable=unused-import
+import compiler_opt.rl.unroll  # pylint: disable=unused-import

 types = tfa.typing.types

compiler_opt/rl/unroll/__init__.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the 'loop unroll' problem."""

import gin

from compiler_opt.rl import problem_configuration
from compiler_opt.rl.unroll import config
from compiler_opt.rl.unroll import unroll_runner


@gin.register(module='configs')
class LoopUnrollConfig(problem_configuration.ProblemConfiguration):
  """Expose the loop unroll components."""

  def get_runner_type(self):
    return unroll_runner.LoopUnrollRunner

  def get_signature_spec(self):
    return config.get_unroll_signature_spec()

  def get_preprocessing_layer_creator(self):
    return config.get_observation_processing_layer_creator()

  def get_nonnormalized_features(self):
    return config.get_nonnormalized_features()
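As a rough sketch of how this configuration is consumed elsewhere in MLGO (hypothetical driver code, not part of this commit; it only exercises the accessors defined above):

# Hypothetical driver snippet, not part of this commit.
from compiler_opt.rl.unroll import LoopUnrollConfig

problem_config = LoopUnrollConfig()
runner_cls = problem_config.get_runner_type()   # unroll_runner.LoopUnrollRunner
time_step_spec, action_spec = problem_config.get_signature_spec()
print(action_spec.name)  # 'unroll_count'
# get_preprocessing_layer_creator() additionally expects quantile_file_dir to
# be bound through gin before it can build the normalization layers.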

compiler_opt/rl/unroll/config.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loop unroll training config."""

import gin
import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step
from compiler_opt.rl import feature_ops


# pylint: disable=g-complex-comprehension
@gin.configurable()
def get_unroll_signature_spec():
  """Returns (time_step_spec, action_spec) for LLVM loop unroll."""
  # LINT.IfChange
  observation_spec = dict(
      (key, tf.TensorSpec(dtype=tf.int64, shape=(), name=key))
      for key in ('loop_size', 'trip_count', 'is_innermost_loop',
                  'preheader_blocksize', 'bb_count', 'num_of_loop_latch',
                  'load_inst_count', 'store_inst_count', 'logical_inst_count',
                  'cast_inst_count'))
  reward_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='reward')
  time_step_spec = time_step.time_step_spec(observation_spec, reward_spec)
  action_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.int64, shape=(), name='unroll_count')

  return time_step_spec, action_spec


@gin.configurable
def get_observation_processing_layer_creator(quantile_file_dir=None,
                                             with_sqrt=True,
                                             with_z_score_normalization=True,
                                             eps=1e-8):
  """Wrapper for observation_processing_layer."""
  quantile_map = feature_ops.build_quantile_map(quantile_file_dir)

  def observation_processing_layer(obs_spec):
    """Creates the layer to process observation given obs_spec."""

    # Reward-related features are presumably dropped from the observation.
    if obs_spec.name in ('icache_pressure', 'latency'):
      return tf.keras.layers.Lambda(feature_ops.discard_fn)

    # For boolean features, use feature_ops.identity_fn.
    if obs_spec.name in ('is_innermost_loop',):
      return tf.keras.layers.Lambda(feature_ops.identity_fn)

    # Open question: do we need an extra layer here to normalize 'loop_size'
    # and the instruction-count features (e.g. 'load_inst_count')? Bigger
    # loops naturally have higher instruction counts, so these features may
    # need normalization.

    quantile = quantile_map[obs_spec.name]
    return tf.keras.layers.Lambda(
        feature_ops.get_normalize_fn(quantile, with_sqrt,
                                     with_z_score_normalization, eps))

  return observation_processing_layer


def get_nonnormalized_features():
  return ['reward', 'is_innermost_loop']
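A minimal sketch of how the two functions above fit together (the vocab directory is a placeholder; real runs point it at generated quantile files):

# Illustrative only; '/tmp/unroll_vocab' is a placeholder quantile directory.
time_step_spec, _ = get_unroll_signature_spec()
creator = get_observation_processing_layer_creator(
    quantile_file_dir='/tmp/unroll_vocab')
processing_layers = {
    name: creator(spec) for name, spec in time_step_spec.observation.items()
}
# Each Lambda layer maps one raw int64 scalar feature (e.g. 'loop_size') to
# its normalized form; 'is_innermost_loop' passes through unchanged.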
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
import gin.tf.external_configurables
import compiler_opt.rl.constant
import compiler_opt.rl.gin_external_configurables
import compiler_opt.rl.unroll.config
import tf_agents.agents.behavioral_cloning.behavioral_cloning_agent
import tf_agents.networks.q_network

include 'compiler_opt/rl/unroll/gin_configs/common.gin'

train_eval.agent_name=%constant.AgentName.BEHAVIORAL_CLONE
train_eval.num_iterations=100000
train_eval.batch_size=64
train_eval.train_sequence_length=1

unroll.config.get_observation_processing_layer_creator.with_sqrt = False
unroll.config.get_observation_processing_layer_creator.with_z_score_normalization = False

create_agent.policy_network = @q_network.QNetwork

QNetwork.fc_layer_params=(40, 40, 20)
QNetwork.dropout_layer_params=(0.2, 0.2, 0.2)

tf.train.AdamOptimizer.learning_rate = 0.001
tf.train.AdamOptimizer.epsilon = 0.0003125

BehavioralCloningAgent.optimizer = @tf.train.AdamOptimizer()
BehavioralCloningAgent.epsilon_greedy = 0.1
BehavioralCloningAgent.gradient_clipping = None
BehavioralCloningAgent.debug_summaries = True
BehavioralCloningAgent.summarize_grads_and_vars = True
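Assuming the repository's existing behavioral-cloning entry point (train_bc.py) and its usual flags, a training launch with this config might look like the following; the paths are placeholders:

python3 compiler_opt/rl/train_bc.py \
  --root_dir=/tmp/unroll_warmstart \
  --data_path=/path/to/unroll_default_trace \
  --gin_files=/path/to/this_behavioral_cloning_config.gin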
compiler_opt/rl/unroll/gin_configs/common.gin

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
config_registry.get_configuration.implementation=@configs.LoopUnrollConfig

clang_path=None
llvm_objcopy_path=None
parse_reward_script_path=None
latency_coefficient=None

runners.LoopUnrollRunner.clang_path=%clang_path
runners.LoopUnrollRunner.llvm_objcopy_path=%llvm_objcopy_path
runners.LoopUnrollRunner.parse_reward_script_path=%parse_reward_script_path
runners.LoopUnrollRunner.latency_coefficient=%latency_coefficient
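The four macros above default to None and are meant to be filled in per run. As a hedged example (paths are placeholders; latency_coefficient is given as a string because LoopUnrollRunner converts it with float()), they can be supplied as gin bindings on the trainer or data-collection command line:

  --gin_bindings=clang_path="'/path/to/clang'"
  --gin_bindings=llvm_objcopy_path="'/path/to/llvm-objcopy'"
  --gin_bindings=parse_reward_script_path="'/path/to/parse_reward.py'"
  --gin_bindings=latency_coefficient="'0.1'"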
compiler_opt/rl/unroll/unroll_runner.py

Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for collecting data for loop unroll."""

import base64
import io
import os
import tempfile
from typing import Dict, Optional, Tuple

import gin
import tensorflow as tf

from google.protobuf import struct_pb2  # pytype: disable=pyi-error
from compiler_opt.rl import compilation_runner
from compiler_opt.rl import corpus


@gin.configurable(module='runners')
class LoopUnrollRunner(compilation_runner.CompilationRunner):
  """Class for collecting data for loop partial unroll.

  Usage:
    runner = LoopUnrollRunner(
        clang_path, llvm_objcopy_path, parse_reward_script_path,
        moving_average_decay_rate)
    policy_reward = unroll.collect_data(
        ir_path, tf_policy_path, default_reward, moving_average_reward)
  """

  def __init__(self, llvm_objcopy_path: str, parse_reward_script_path: str,
               latency_coefficient: str, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._llvm_objcopy_path = llvm_objcopy_path
    self._parse_reward_script_path = parse_reward_script_path
    self._latency_coefficient = float(latency_coefficient)

  def compile_fn(
      self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
      reward_only: bool, cancellation_manager: Optional[
          compilation_runner.WorkerCancellationManager]
  ) -> Dict[str, Tuple[tf.train.SequenceExample, float]]:
    """Runs loop unroll for the given IR file under the given policy.

    Args:
      module_spec: a ModuleSpec.
      tf_policy_path: path to a TF policy directory on local disk.
      reward_only: whether to only return the reward (icache pressure and
        latency).
      cancellation_manager: handler for early termination by killing any
        running processes.

    Returns:
      For loop unroll, the result is at module level. IWS and latency are
      already weighted by the probability of being executed (see
      parse_reward.py and the code embedded under AsmPrinter.cpp for more
      detail).

      Because the reward is calculated at a late stage in the compiler, after
      inlining, some functions may have been inlined away and cannot be found
      for some loops, so we sum the rewards of all functions into a single
      float, reward_total.

      The function returns a dict in the format:
        {
          "loop1_key": (loop1_features, reward_total),
          "loop2_key": (loop2_features, reward_total),
          ...,
          "loopN_key": (loopN_features, reward_total)
        }
      - reward_total: sum of the IWS and latency of all functions in this
        module.

    Early return:
      The function returns early when the compiled module doesn't record any
      logs or the log file doesn't record any loop. This happens when
      `LoopUnrollPass` is not triggered or no loop triggered "partial unroll"
      in the pass.
    """
    working_dir = tempfile.mkdtemp()

    # The compiler logs the input features (loop properties) and the decision
    # (unroll count) into the specified log path.
    log_path = os.path.join(working_dir, 'log')

    # The compilation generates object files, and our augmentation under
    # AsmPrinter.cpp creates the section data `llvm_block_data`.
    object_path = os.path.join(working_dir, 'object')
    # llvm-objcopy extracts the section data from object to data.
    data_path = os.path.join(working_dir, 'data')
    # The reward parsing script parses data into parsed_reward.
    parsed_reward_path = os.path.join(working_dir, 'parsed_reward')

    try:
      # Construct the command to execute clang.
      command_line = []

      # Parameters for MLGO unroll.
      command_line.extend([self._clang_path] + list(module_spec.exec_cmd) + [
          '-mllvm', '-mlgo-unroll-mode=training', '-mllvm',
          '-mlgo-unroll-training-log=' +
          log_path, '-mllvm', '-calc-reward', '-o', object_path
      ])

      # Under `training mode`...
      # If a model path is provided, the compiler uses ModelUnderTrainingRunner.
      # Otherwise, the compiler uses NoInferenceModelRunner.
      if tf_policy_path:
        command_line.extend(
            ['-mllvm', '-mlgo-unroll-train-model=' + tf_policy_path])

      print('Command to execute clang: ', command_line)

      # Run clang.
      compilation_runner.start_cancellable_process(command_line,
                                                   self._compilation_timeout,
                                                   cancellation_manager)

      # A module may not generate a log if none of the loops reach the
      # LoopUnroll decision. Return early if log_path cannot be found.
      if not os.path.exists(log_path):
        print('Early return, log file not found.')
        return {}

      # A log file may be empty when none of the loops reach a PartialUnroll
      # decision. Return early if a log file is created but has nothing
      # inside.
      if os.path.getsize(log_path) == 0:
        print('Early return, log file contains nothing.')
        return {}

      # Run llvm-objcopy to get the section data.
      command_line = [
          self._llvm_objcopy_path,
          '--dump-section=.llvm_block_data.=' + data_path, object_path
      ]
      print('Command to get section data: ', command_line)
      compilation_runner.start_cancellable_process(command_line,
                                                   self._compilation_timeout,
                                                   cancellation_manager)

      # Run parse_reward.py to get the reward.
      command_line = [
          self._parse_reward_script_path, data_path, parsed_reward_path
      ]
      print('Command to parse reward: ', command_line)
      compilation_runner.start_cancellable_process(command_line,
                                                   self._compilation_timeout,
                                                   cancellation_manager)

      # Sum the rewards of all functions into a single float.
      reward_total = 0
      with io.open(parsed_reward_path, 'r', encoding='utf-8') as reward_f:
        for line in reward_f.readlines():
          line = line[:-1]  # strip end-of-line
          items = line.split(',')
          assert len(items) == 3
          # function_name = items[0] (commented out because currently unused)
          iws = float(items[1])
          latency = float(items[2])
          reward_total = reward_total + (
              iws + latency * self._latency_coefficient)

      if reward_only:
        return {'default': (None, reward_total)}

      result = {}

      # Read the training log and fill it into the result.
      sequence_examples = struct_pb2.Struct()
      with io.open(log_path, 'rb') as log_f:
        sequence_examples.ParseFromString(log_f.read())

      for key, value in sequence_examples.fields.items():
        entry = tf.train.SequenceExample()
        entry.ParseFromString(base64.b64decode(value.string_value))

        if not entry.HasField('feature_lists'):
          continue

        result[key] = (entry, reward_total)

    finally:
      tf.io.gfile.rmtree(working_dir)

    return result
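For intuition, here is a small worked example of the reward aggregation in compile_fn above (the numbers are made up):

# Made-up parsed_reward contents: 'function_name,IWS,latency' per line.
latency_coefficient = 0.1
parsed_lines = ['foo,100.0,20.0', 'bar,50.0,10.0']
reward_total = 0.0
for line in parsed_lines:
  _name, iws, latency = line.split(',')
  reward_total += float(iws) + float(latency) * latency_coefficient
print(reward_total)  # (100.0 + 2.0) + (50.0 + 1.0) = 153.0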

compiler_opt/tools/sparse_bucket_generator.py

Lines changed: 1 addition & 1 deletion
@@ -170,7 +170,7 @@ def main(_) -> None:
   parser_fn = create_tfrecord_parser_fn(sequence_features)
   dataset = dataset.map(parser_fn, num_parallel_calls=tf.data.AUTOTUNE)
   data_list = np.array(list(dataset.as_numpy_iterator()), dtype=object)
-  data_list = np.transpose(data_list, [1, 0])
+  data_list = np.transpose(data_list, [1, 0, 2])

   with mp.Pool(FLAGS.parallelism) as pool:
     feature_names = list(sorted(sequence_features))
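The new axes argument only makes sense if data_list is now three-dimensional, which the change implies; a minimal numpy illustration of what [1, 0, 2] does (swap the first two axes, keep the innermost one):

import numpy as np

data_list = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # (records, features, inner)
swapped = np.transpose(data_list, [1, 0, 2])
print(swapped.shape)  # (3, 2, 4): features first, records second, inner axis kept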
