Skip to content

Commit 87104f3

Browse files
authored
Add caching to input pipeline (#34)
Caching the results pre-shuffling prevents TF from recomputing the pre-shuffle batches from scratch repeatedly. A speedup of more than 5x is expected for the input pipeline, and similarly for the training cycle times.
1 parent 906a2f3 commit 87104f3

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

compiler_opt/rl/data_reader.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -165,15 +165,16 @@ def create_sequence_example_dataset_fn(
165165
def _sequence_example_dataset_fn(sequence_examples):
166166
# Data collector returns empty strings for corner cases, filter them out
167167
# here.
168-
dataset = tf.data.Dataset.from_tensor_slices(sequence_examples).filter(
169-
lambda string: tf.strings.length(string) > 0).map(parser_fn).filter(
170-
lambda traj: tf.size(traj.reward) > 2)
171-
dataset = (
172-
dataset.unbatch().batch(
173-
train_sequence_length,
174-
drop_remainder=True).shuffle(trajectory_shuffle_buffer_size).batch(
175-
batch_size,
176-
drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE))
168+
dataset = (tf.data.Dataset.from_tensor_slices(sequence_examples)
169+
.filter(lambda string: tf.strings.length(string) > 0)
170+
.map(parser_fn)
171+
.filter(lambda traj: tf.size(traj.reward) > 2)
172+
.unbatch()
173+
.batch(train_sequence_length, drop_remainder=True)
174+
.cache()
175+
.shuffle(trajectory_shuffle_buffer_size)
176+
.batch(batch_size, drop_remainder=True)
177+
)
177178
return dataset
178179

179180
return _sequence_example_dataset_fn

compiler_opt/rl/train_locally.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def train_eval(agent_name=constant.AgentName.PPO,
126126
train_sequence_length=train_sequence_length)
127127

128128
def sequence_example_iterator_fn(seq_ex: List[str]):
129-
return iter(dataset_fn(seq_ex).repeat())
129+
return iter(dataset_fn(seq_ex).repeat().prefetch(tf.data.AUTOTUNE))
130130

131131
reward_stat_map = collections.defaultdict(lambda: None)
132132
reward_stat_map_path = os.path.join(root_dir, 'reward_stat_map')

0 commit comments

Comments (0)