Skip to content

Commit 0d15a32

Browse files
added config options for vocab generation (#30)
1 parent a525b3e commit 0d15a32

File tree

8 files changed

+68
-20
lines changed

8 files changed

+68
-20
lines changed

compiler_opt/rl/inlining/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,6 @@ def get_signature_spec(self):
3333

3434
def get_preprocessing_layer_creator(self):
3535
return config.get_observation_processing_layer_creator()
36+
37+
def get_nonnormalized_features(self):
38+
return config.get_nonnormalized_features()

compiler_opt/rl/inlining/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,7 @@ def observation_processing_layer(obs_spec):
9797
with_z_score_normalization, eps))
9898

9999
return observation_processing_layer
100+
101+
def get_nonnormalized_features():
102+
return ['reward', 'inlining_default', 'inlining_decision']
103+

compiler_opt/rl/problem_configuration.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,12 @@ def get_preprocessing_layer_creator(
9393
self) -> Callable[[types.TensorSpec], tf.keras.layers.Layer]:
9494
raise NotImplementedError
9595

96+
def get_nonnormalized_features(self) -> Iterable[str]:
97+
return []
98+
9699
@abc.abstractmethod
97100
def get_runner_type(self) -> 'type[compilation_runner.CompilationRunner]':
98101
raise NotImplementedError
99102

100-
101103
def is_thinlto(module_paths: Iterable[str]) -> bool:
102104
return tf.io.gfile.exists(next(iter(module_paths)) + '.thinlto.bc')

compiler_opt/rl/regalloc/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,6 @@ def get_signature_spec(self):
3333

3434
def get_preprocessing_layer_creator(self):
3535
return config.get_observation_processing_layer_creator()
36+
37+
def get_nonnormalized_features(self):
38+
return config.get_nonnormalized_features()

compiler_opt/rl/regalloc/config.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -88,22 +88,26 @@ def observation_processing_layer(obs_spec):
8888
if obs_spec.name in ('max_stage', 'min_stage'):
8989
return tf.keras.layers.Embedding(7, 4)
9090

91-
quantile = quantile_map[obs_spec.name]
92-
93-
first_non_zero = 0
94-
for x in quantile:
95-
if x > 0:
96-
first_non_zero = x
97-
break
98-
99-
normalize_fn = feature_ops.get_normalize_fn(quantile, with_sqrt,
100-
with_z_score_normalization, eps)
101-
log_normalize_fn = feature_ops.get_normalize_fn(
102-
quantile,
103-
with_sqrt,
104-
with_z_score_normalization,
105-
eps,
106-
preprocessing_fn=lambda x: tf.math.log(x + first_non_zero))
91+
normalize_fn = log_normalize_fn = None
92+
if obs_spec.name not in get_nonnormalized_features():
93+
quantile = quantile_map[obs_spec.name]
94+
95+
first_non_zero = 0
96+
for x in quantile:
97+
if x > 0:
98+
first_non_zero = x
99+
break
100+
101+
normalize_fn = feature_ops.get_normalize_fn(quantile,
102+
with_sqrt,
103+
with_z_score_normalization,
104+
eps)
105+
log_normalize_fn = feature_ops.get_normalize_fn(
106+
quantile,
107+
with_sqrt,
108+
with_z_score_normalization,
109+
eps,
110+
preprocessing_fn=lambda x: tf.math.log(x + first_non_zero))
107111

108112
if obs_spec.name in ['nr_rematerializable', 'nr_broken_hints']:
109113
return tf.keras.layers.Lambda(normalize_fn)
@@ -137,3 +141,9 @@ def progress_processing_fn(obs):
137141
raise KeyError('Missing preprocessing function for some feature.')
138142

139143
return observation_processing_layer
144+
145+
def get_nonnormalized_features():
146+
return ['mask', 'nr_urgent',
147+
'is_hint', 'is_local',
148+
'is_free', 'max_stage',
149+
'min_stage', 'reward']

compiler_opt/tools/sparse_bucket_generator.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,18 @@
2121
import math
2222
import multiprocessing as mp
2323
import os
24-
from typing import Callable, Dict, List
24+
from typing import Callable, Dict, List, Iterable
2525

2626
from absl import app
2727
from absl import flags
2828
from absl import logging
29+
import gin
2930

3031
import numpy as np
3132
import tensorflow as tf
3233

34+
from compiler_opt.rl import registry
35+
3336
flags.DEFINE_string('input', None,
3437
'Path to input file containing tf record datasets.')
3538
flags.DEFINE_string('output_dir', None,
@@ -41,16 +44,23 @@
4144
'Each process does vocab generation for each feature.', 1)
4245
flags.DEFINE_integer('num_buckets', 1000,
4346
'Number of quantiles to bucketize feature values into.')
47+
flags.DEFINE_multi_string('gin_files', [],
48+
'List of paths to gin configuration files.')
49+
flags.DEFINE_multi_string(
50+
'gin_bindings', [],
51+
'Gin bindings to override the values set in the config files.')
4452

4553
FLAGS = flags.FLAGS
4654

4755

4856
def _get_feature_info(
49-
serialized_proto: tf.Tensor) -> Dict[str, tf.io.RaggedFeature]:
57+
serialized_proto: tf.Tensor,
58+
features_to_not_process: Iterable[str]) -> Dict[str, tf.io.RaggedFeature]:
5059
"""Provides feature information by analyzing a single serialized example.
5160
5261
Args:
5362
serialized_proto: serialized SequenceExample.
63+
features_to_not_process: A list of feature names that should not be processed
5464
5565
Returns:
5666
Dictionary of Tensor formats indexed by feature name.
@@ -59,6 +69,8 @@ def _get_feature_info(
5969
example.ParseFromString(serialized_proto.numpy())
6070
sequence_features = {}
6171
for key, feature_list in example.feature_lists.feature_list.items():
72+
if key in features_to_not_process:
73+
continue
6274
feature = feature_list.feature[0]
6375
kind = feature.WhichOneof('kind')
6476
if kind == 'float_list':
@@ -123,17 +135,23 @@ def _generate_vocab(feature_values_arrays, feature_name):
123135

124136

125137
def main(_) -> None:
138+
gin.parse_config_files_and_bindings(
139+
FLAGS.gin_files, bindings=FLAGS.gin_bindings, skip_unknown=False)
140+
logging.info(gin.config_str())
141+
problem_config = registry.get_configuration()
142+
126143
"""Generate num_buckets quantiles for each feature."""
127144
tf.io.gfile.makedirs(FLAGS.output_dir)
128145
dataset = tf.data.Dataset.list_files(FLAGS.input)
129146
dataset = tf.data.TFRecordDataset(dataset)
147+
features_to_not_process = problem_config.get_nonnormalized_features()
130148

131149
sequence_features = {}
132150
# TODO(b/222775595): need to fix this after update to logic for handling
133151
# empty examples during trace generation.
134152
for raw_example in dataset:
135153
try:
136-
sequence_features = _get_feature_info(raw_example)
154+
sequence_features = _get_feature_info(raw_example, features_to_not_process)
137155
logging.info('Found valid sequence_features dict: %s', sequence_features)
138156
break
139157
except IndexError:

docs/adding_features.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ First and foremost, **you must regenerate the vocabulary** - technically you
2929
just need a vocab file for the new feature, but it's simpler to regenerate it
3030
all. See the [demo section](demo/demo.md#collect-trace-and-generate-vocab)
3131

32+
**Note:** You only need to regenerate the vocabulary if the feature is going
33+
to be normalized by a preprocessing layer for your model. If your feature does
34+
not need to be put through a lambda normalization preprocessing layer, make sure
35+
your feature is added to the list
36+
returned by `get_nonnormalized_features()` in `config.py`. In either case,
37+
it is still quite simple and fast to just call the vocab generation again.
38+
3239
After that, retrain from [scratch](demo/demo.md#train-a-new-model).
3340

3441
## Notes

docs/demo/demo.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ in the trace changes.
266266
rm -rf $DEFAULT_VOCAB &&
267267
PYTHONPATH=$PYTHONPATH:. python3 \
268268
compiler_opt/tools/sparse_bucket_generator.py \
269+
--gin_files=compiler_opt/rl/inlining/gin_configs/common.gin \
269270
--input=$DEFAULT_TRACE \
270271
--output_dir=$DEFAULT_VOCAB
271272
```

0 commit comments

Comments
 (0)