Commit 25af176

Define a problem config (#158)
1 parent 18c0f77 commit 25af176

File tree: 5 files changed (+209, -0 lines)

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
import gin

from compiler_opt.rl import problem_configuration
from compiler_opt.rl.regalloc_priority import config
from compiler_opt.rl.regalloc_priority import regalloc_priority_runner


@gin.register(module='configs')
class RegallocPriorityConfig(problem_configuration.ProblemConfiguration):

  def get_runner(self, *args, **kwargs):
    return regalloc_priority_runner.RegAllocPriorityRunner(*args, **kwargs)

  def get_signature_spec(self):
    return config.get_regalloc_signature_spec()

  def get_preprocessing_layer_creator(self):
    return config.get_observation_processing_layer_creator()
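
A minimal usage sketch (not part of this commit): instantiate the new problem
configuration directly and inspect the specs it exposes. The module path below
is an assumption inferred from the imports above; runner construction is
normally driven by the gin bindings in common.gin further down.

# Hypothetical sketch; module path assumed, not confirmed by the diff.
from compiler_opt.rl import regalloc_priority

problem_config = regalloc_priority.RegallocPriorityConfig()
time_step_spec, action_spec = problem_config.get_signature_spec()
layer_creator = problem_config.get_preprocessing_layer_creator()

# Observations are scalar 'li_size', 'stage' (int64) and 'weight' (float32);
# the action is a scalar float32 'priority'.
print(list(time_step_spec.observation.keys()))
print(action_spec)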

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
import gin
import tensorflow as tf
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step
from compiler_opt.rl import feature_ops


@gin.configurable()
def get_regalloc_signature_spec():
  observation_spec = dict(
      (key, tf.TensorSpec(dtype=tf.int64, shape=(), name=key))
      for key in ('li_size', 'stage'))
  observation_spec['weight'] = tf.TensorSpec(
      dtype=tf.float32, shape=(), name='weight')

  reward_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='reward')
  time_step_spec = time_step.time_step_spec(observation_spec, reward_spec)

  action_spec = tensor_spec.TensorSpec(
      dtype=tf.float32, shape=(), name='priority')

  return time_step_spec, action_spec


@gin.configurable
def get_observation_processing_layer_creator():

  def observation_processing_layer(obs_spec):
    """Creates the layer to process observation given obs_spec."""
    if obs_spec.name in ('li_size', 'stage', 'weight'):
      return tf.keras.layers.Lambda(feature_ops.identity_fn)

    # Make sure all features have a preprocessing function.
    raise KeyError('Missing preprocessing function for some feature.')

  return observation_processing_layer
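
A short sketch (not part of this commit, and assumed to run in the same module
as the code above) of how the signature spec and the layer creator would
typically be combined to build per-feature preprocessing layers:

time_step_spec, _ = get_regalloc_signature_spec()
creator = get_observation_processing_layer_creator()

# One Keras layer per observation feature. Every feature named in the spec
# ('li_size', 'stage', 'weight') maps to an identity Lambda layer; any other
# feature name would hit the KeyError above.
preprocessing_layers = {
    name: creator(spec)
    for name, spec in time_step_spec.observation.items()
}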

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
config_registry.get_configuration.implementation=@configs.RegallocPriorityConfig

launcher_path=None
clang_path=None

runners.RegAllocPriorityRunner.clang_path=%clang_path
runners.RegAllocPriorityRunner.launcher_path=%launcher_path
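
The two None placeholders above are gin macros, so a launch is expected to
rebind them before anything builds the runner. A hedged sketch (paths are
placeholders; how the bindings reach gin depends on the training binary):

import gin

# Bindings parsed after the file override the None macros in common.gin.
# skip_unknown=True keeps this isolated sketch from failing on configurables
# whose defining modules are not imported here.
gin.parse_config_files_and_bindings(
    config_files=['compiler_opt/rl/regalloc_priority/gin_configs/common.gin'],
    bindings=[
        "clang_path='/path/to/clang'",
        "launcher_path=''",
    ],
    skip_unknown=True)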

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
import gin.tf.external_configurables
import compiler_opt.rl.constant
import compiler_opt.rl.constant_value_network
import compiler_opt.rl.gin_external_configurables
import compiler_opt.rl.regalloc_priority.config
import compiler_opt.rl.regalloc_network
import tf_agents.agents.ppo.ppo_agent
import tf_agents.networks.actor_distribution_network

include 'compiler_opt/rl/regalloc_priority/gin_configs/common.gin'

train_eval.agent_name=%constant.AgentName.PPO
train_eval.warmstart_policy_dir=''
train_eval.num_policy_iterations=3000
train_eval.num_iterations=200
train_eval.batch_size=256
train_eval.train_sequence_length=16
train_eval.deploy_policy_name='saved_collect_policy'
train_eval.moving_average_decay_rate=0.8
train_eval.use_random_network_distillation=False

#######################################
# Turn on if using train_with_rpc.py
# train_eval.additional_compilation_flags=()
#######################################

# RandomNetworkDistillation configs, off if train_eval.use_random_network_distillation=False.
RandomNetworkDistillation.encoding_network = @regalloc_network.RegAllocRNDEncodingNetwork
RandomNetworkDistillation.learning_rate = 1e-4
RandomNetworkDistillation.update_frequency = 2
RandomNetworkDistillation.fc_layer_params = [32, 128]
RandomNetworkDistillation.initial_intrinsic_reward_scale = 1.0
RandomNetworkDistillation.half_decay_steps = 10000

create_agent.policy_network = @actor_distribution_network.ActorDistributionNetwork

ActorDistributionNetwork.preprocessing_combiner=@tf.keras.layers.Concatenate()
ActorDistributionNetwork.fc_layer_params=(40, 40, 20)
ActorDistributionNetwork.dropout_layer_params=None
ActorDistributionNetwork.activation_fn=@tf.keras.activations.relu

NormalProjectionNetwork.mean_transform=None

ConstantValueNetwork.constant_output_val=0

tf.train.AdamOptimizer.learning_rate = 0.0003
tf.train.AdamOptimizer.epsilon = 0.0003125

PPOAgent.optimizer = @tf.train.AdamOptimizer()
PPOAgent.importance_ratio_clipping = 0.2
PPOAgent.lambda_value = 0.0
PPOAgent.discount_factor = 0.0
PPOAgent.entropy_regularization = 0.005
PPOAgent.policy_l2_reg = 0.00001
PPOAgent.value_function_l2_reg = 0.0
PPOAgent.shared_vars_l2_reg = 0.0
PPOAgent.value_pred_loss_coef = 0.0
PPOAgent.num_epochs = 1
PPOAgent.use_gae = False
PPOAgent.use_td_lambda_return = False
PPOAgent.normalize_rewards = False
PPOAgent.reward_norm_clipping = 10.0
PPOAgent.normalize_observations = False
PPOAgent.log_prob_clipping = 0.0
PPOAgent.kl_cutoff_factor = 2.0
PPOAgent.kl_cutoff_coef = 1000.0
PPOAgent.initial_adaptive_kl_beta = 1.0
PPOAgent.adaptive_kl_target = 0.01
PPOAgent.adaptive_kl_tolerance = 0.3
PPOAgent.gradient_clipping = None
PPOAgent.value_clipping = None
PPOAgent.check_numerics = False
PPOAgent.compute_value_and_advantage_in_train = True
PPOAgent.update_normalizers_in_train=True
PPOAgent.debug_summaries = True
PPOAgent.summarize_grads_and_vars = True
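
For reference, a hedged sketch (not part of this commit) of what the Adam
bindings above resolve to: gin.tf.external_configurables registers the
TF1-style optimizer that tf_agents' PPOAgent consumes, so the equivalent
hand-written construction would look roughly like this.

import tensorflow as tf

# Values copied from the config above; PPOAgent receives this object through
# the PPOAgent.optimizer = @tf.train.AdamOptimizer() binding.
optimizer = tf.compat.v1.train.AdamOptimizer(
    learning_rate=0.0003, epsilon=0.0003125)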

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
import gin
import tensorflow as tf

import base64
import io
import os
import tempfile
from typing import Dict, Optional, Tuple
from absl import logging

from google.protobuf import struct_pb2
from compiler_opt.rl import compilation_runner


@gin.configurable(module='runners')
class RegAllocPriorityRunner(compilation_runner.CompilationRunner):

  def _compile_fn(
      self, file_paths: Tuple[str, ...], tf_policy_path: str, reward_only: bool,
      cancellation_manager: Optional[
          compilation_runner.WorkerCancellationManager]
  ) -> Dict[str, Tuple[tf.train.SequenceExample, float]]:

    file_paths = file_paths[0].replace('.bc', '')
    working_dir = tempfile.mkdtemp()

    log_path = os.path.join(working_dir, 'log')
    output_native_path = os.path.join(working_dir, 'native')

    result = {}
    try:
      command_line = []
      if self._launcher_path:
        command_line.append(self._launcher_path)
      command_line.extend([self._clang_path] + [
          '-c', file_paths, '-O3',
          '-mllvm', '-regalloc-priority-training-log=' + log_path,
          '-mllvm', '-regalloc-enable-priority-advisor=development',
          '-o', output_native_path
      ])

      if tf_policy_path:
        command_line.extend(
            ['-mllvm', '-regalloc-priority-model=' + tf_policy_path])
      compilation_runner.start_cancellable_process(command_line,
                                                   self._compilation_timeout,
                                                   cancellation_manager)

      sequence_example = struct_pb2.Struct()

      with io.open(log_path, 'rb') as f:
        sequence_example.ParseFromString(f.read())

      for key, value in sequence_example.fields.items():
        e = tf.train.SequenceExample()
        e.ParseFromString(base64.b64decode(value.string_value))
        logging.debug(e)
        if not e.HasField('feature_lists'):
          continue
        r = (
            e.feature_lists.feature_list['reward'].feature[-1].float_list
            .value[0])
        if reward_only:
          result[key] = (None, r)
        else:
          del e.feature_lists.feature_list['reward']
          result[key] = (e, r)

    finally:
      tf.io.gfile.rmtree(working_dir)

    return result
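
A small helper sketch (not part of this commit) isolating the decoding step
from the loop above: given one base64-encoded value from the training-log
Struct, recover the SequenceExample and its final per-function reward.

import base64
from typing import Tuple

import tensorflow as tf


def decode_log_entry(b64_value: str) -> Tuple[tf.train.SequenceExample, float]:
  """Decodes one logged entry and returns (example, final reward)."""
  example = tf.train.SequenceExample()
  example.ParseFromString(base64.b64decode(b64_value))
  reward = (
      example.feature_lists.feature_list['reward'].feature[-1].float_list
      .value[0])
  return example, reward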
