Skip to content

Commit 00129b5

Browse files
authored
fix regalloc priority prediction lint (#159)
* lint * more fix
1 parent 25af176 commit 00129b5

File tree

3 files changed

+121
-72
lines changed

3 files changed

+121
-72
lines changed
Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Implementation of the 'regalloc priority prediction' problem."""
116

217
import gin
318

@@ -8,11 +23,12 @@
823

924
@gin.register(module='configs')
1025
class RegallocPriorityConfig(problem_configuration.ProblemConfiguration):
11-
def get_runner(self, *args, **kwargs):
12-
return regalloc_priority_runner.RegAllocPriorityRunner(*args, **kwargs)
1326

14-
def get_signature_spec(self):
15-
return config.get_regalloc_signature_spec()
27+
def get_runner(self, *args, **kwargs):
28+
return regalloc_priority_runner.RegAllocPriorityRunner(*args, **kwargs)
1629

17-
def get_preprocessing_layer_creator(self):
18-
return config.get_observation_processing_layer_creator()
30+
def get_signature_spec(self):
31+
return config.get_regalloc_signature_spec()
32+
33+
def get_preprocessing_layer_creator(self):
34+
return config.get_observation_processing_layer_creator()
Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""RegAlloc priority prediction training config."""
16+
117
import gin
218
import tensorflow as tf
319
from tf_agents.specs import tensor_spec
@@ -7,32 +23,31 @@
723

824
@gin.configurable()
925
def get_regalloc_signature_spec():
10-
observation_spec = dict(
11-
(key, tf.TensorSpec(dtype=tf.int64, shape=(), name=key))
12-
for key in ('li_size', 'stage'))
13-
observation_spec['weight'] = tf.TensorSpec(dtype=tf.float32, shape=(), name='weight')
26+
observation_spec = dict(
27+
(key, tf.TensorSpec(dtype=tf.int64, shape=(), name=key))
28+
for key in ('li_size', 'stage'))
29+
observation_spec['weight'] = tf.TensorSpec(
30+
dtype=tf.float32, shape=(), name='weight')
1431

15-
reward_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='reward')
16-
time_step_spec = time_step.time_step_spec(observation_spec, reward_spec)
32+
reward_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='reward')
33+
time_step_spec = time_step.time_step_spec(observation_spec, reward_spec)
1734

18-
action_spec = tensor_spec.TensorSpec(
19-
dtype=tf.float32,
20-
shape=(),
21-
name='priority'
22-
)
35+
action_spec = tensor_spec.TensorSpec(
36+
dtype=tf.float32, shape=(), name='priority')
2337

24-
return time_step_spec, action_spec
38+
return time_step_spec, action_spec
2539

2640

2741
@gin.configurable
2842
def get_observation_processing_layer_creator():
29-
def observation_processing_layer(obs_spec):
30-
"""Creates the layer to process observation given obs_spec."""
3143

32-
if obs_spec.name in ('li_size', 'stage', 'weight'):
33-
return tf.keras.layers.Lambda(feature_ops.identity_fn)
44+
def observation_processing_layer(obs_spec):
45+
"""Creates the layer to process observation given obs_spec."""
46+
47+
if obs_spec.name in ('li_size', 'stage', 'weight'):
48+
return tf.keras.layers.Lambda(feature_ops.identity_fn)
3449

35-
# Make sure all features have a preprocessing function.
36-
raise KeyError('Missing preprocessing function for some feature.')
50+
# Make sure all features have a preprocessing function.
51+
raise KeyError('Missing preprocessing function for some feature.')
3752

38-
return observation_processing_layer
53+
return observation_processing_layer
Lines changed: 65 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Module for collect data of regalloc priority prediction."""
16+
117
import gin
218
import tensorflow as tf
319

@@ -6,65 +22,67 @@
622
import os
723
import tempfile
824
from typing import Dict, Optional, Tuple
9-
from absl import logging
1025

11-
from google.protobuf import struct_pb2
26+
# This is https://github.com/google/pytype/issues/764
27+
from google.protobuf import struct_pb2 # pytype: disable=pyi-error
1228
from compiler_opt.rl import compilation_runner
1329

1430

1531
@gin.configurable(module='runners')
1632
class RegAllocPriorityRunner(compilation_runner.CompilationRunner):
17-
def _compile_fn(
18-
self, file_paths: Tuple[str, ...], tf_policy_path: str, reward_only: bool,
19-
cancellation_manager: Optional[
20-
compilation_runner.WorkerCancellationManager]
21-
) -> Dict[str, Tuple[tf.train.SequenceExample, float]]:
33+
"""Class for collecting data for regalloc-priority-prediction."""
34+
35+
def _compile_fn(
36+
self, file_paths: Tuple[str, ...], tf_policy_path: str, reward_only: bool,
37+
cancellation_manager: Optional[
38+
compilation_runner.WorkerCancellationManager]
39+
) -> Dict[str, Tuple[tf.train.SequenceExample, float]]:
2240

23-
file_paths = file_paths[0].replace('.bc', '')
24-
working_dir = tempfile.mkdtemp()
41+
file_paths = file_paths[0].replace('.bc', '')
42+
working_dir = tempfile.mkdtemp()
2543

26-
log_path = os.path.join(working_dir, 'log')
27-
output_native_path = os.path.join(working_dir, 'native')
44+
log_path = os.path.join(working_dir, 'log')
45+
output_native_path = os.path.join(working_dir, 'native')
2846

29-
result = {}
30-
try:
31-
command_line = []
32-
if self._launcher_path:
33-
command_line.append(self._launcher_path)
34-
command_line.extend([self._clang_path] + [
35-
'-c', file_paths, '-O3',
36-
'-mllvm', '-regalloc-priority-training-log=' + log_path,
37-
'-mllvm', '-regalloc-enable-priority-advisor=development',
38-
'-o', output_native_path
39-
])
47+
result = {}
48+
try:
49+
command_line = []
50+
if self._launcher_path:
51+
command_line.append(self._launcher_path)
52+
command_line.extend([self._clang_path] + [
53+
'-c', file_paths, '-O3', '-mllvm', '-regalloc-priority-training-log='
54+
+ log_path, '-mllvm', '-regalloc-enable-priority-advisor=development',
55+
'-o', output_native_path
56+
])
4057

41-
if tf_policy_path:
42-
command_line.extend(['-mllvm', '-regalloc-priority-model=' + tf_policy_path])
43-
compilation_runner.start_cancellable_process(command_line,
44-
self._compilation_timeout,
45-
cancellation_manager)
58+
if tf_policy_path:
59+
command_line.extend(
60+
['-mllvm', '-regalloc-priority-model=' + tf_policy_path])
61+
compilation_runner.start_cancellable_process(command_line,
62+
self._compilation_timeout,
63+
cancellation_manager)
4664

47-
sequence_example = struct_pb2.Struct()
65+
sequence_example = struct_pb2.Struct()
4866

49-
with io.open(log_path, 'rb') as f:
50-
sequence_example.ParseFromString(f.read())
67+
with io.open(log_path, 'rb') as f:
68+
sequence_example.ParseFromString(f.read())
5169

52-
for key, value in sequence_example.fields.items():
53-
e = tf.train.SequenceExample()
54-
e.ParseFromString(base64.b64decode(value.string_value))
55-
print(e)
56-
if not e.HasField('feature_lists'):
57-
continue
58-
r = (
59-
e.feature_lists.feature_list['reward'].feature[-1].float_list
60-
.value[0])
61-
if reward_only:
62-
result[key] = (None, r)
63-
else:
64-
del e.feature_lists.feature_list['reward']
65-
result[key] = (e, r)
70+
for key, value in sequence_example.fields.items():
71+
e = tf.train.SequenceExample()
72+
e.ParseFromString(base64.b64decode(value.string_value))
73+
print(e)
74+
if not e.HasField('feature_lists'):
75+
continue
76+
r = (
77+
e.feature_lists.feature_list['reward'].feature[-1].float_list
78+
.value[0])
79+
if reward_only:
80+
result[key] = (None, r)
81+
else:
82+
del e.feature_lists.feature_list['reward']
83+
result[key] = (e, r)
6684

67-
finally:
68-
tf.io.gfile.rmtree(working_dir)
85+
finally:
86+
tf.io.gfile.rmtree(working_dir)
6987

70-
return result
88+
return result

0 commit comments

Comments
 (0)