
Commit a2d19d3

Add TensorFlow profiler to training loop. (#296)
* Use the TF Profiler in sampling mode through the gRPC server API.
* This enables on-demand, remote sampling with TPUs or multiple workers.
* Add unit test for TF Profiler.
* Tests the profiling server by sending a request and ensuring the profile is
  written to the expected location.
1 parent 8ac007f commit a2d19d3

Files changed (3 files, +98 −8 lines):

* gematria/model/python/main_function.py
* gematria/model/python/main_function_test.py
* gematria/model/python/model_base.py

gematria/model/python/main_function.py

Lines changed: 20 additions & 0 deletions
@@ -424,6 +424,23 @@ def main(_):
     ),
 )
 
+_GEMATRIA_RUN_TF_PROFILER = flags.DEFINE_bool(
+    'gematria_run_tf_profiler',
+    False,
+    'Whether the TensorFlow profiler gRPC server is started or not. When set,'
+    ' the server will listen to `gematria_tf_profiler_port` for requests for'
+    ' on-demand profiling. Requests can be sent through'
+    ' `tf.profiler.experimental.client.trace` or through the TensorBoard GUI.',
+)
+_GEMATRIA_TF_PROFILER_PORT = flags.DEFINE_integer(
+    'gematria_tf_profiler_port',
+    6009,
+    (
+        'When running under the TensorFlow profiler, this is the port the'
+        ' gRPC server listens for tracing requests from.'
+    ),
+)
+
 
 @flags.validator(
     _COLLECTED_PERCENTILE_RANKS.name,
@@ -825,6 +842,9 @@ def checkpoint_model():
         _GEMATRIA_SUMMARY_DIR.value
     )
 
+    if _GEMATRIA_RUN_TF_PROFILER.value:
+      tf.profiler.experimental.server.start(_GEMATRIA_TF_PROFILER_PORT.value)
+
     with train_summary_writer.as_default(), tf.summary.record_if(
         lambda: tf.equal(
            model.global_step % _GEMATRIA_SAVE_SUMMARIES_EPOCHS.value, 0
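
Taken together, the two flags wire the TF Profiler's sampling-mode workflow into the training binary: the training process hosts a gRPC server, and a separate client asks it for a trace on demand. A minimal sketch of the same workflow outside of Gematria (the function names, port, and log directory below are illustrative, not part of this change):

import tensorflow as tf

def training_process(port=6009):
  # What --gematria_run_tf_profiler enables: start the profiler's gRPC
  # server next to the training loop. It stays idle until a trace request
  # arrives, so the steady-state overhead is negligible.
  tf.profiler.experimental.server.start(port)
  # ... run the training loop here ...

def profiling_client(logdir, port=6009):
  # Runs in a different process, possibly on a different machine, while
  # training is in progress. TensorBoard's "Capture Profile" button sends
  # the same kind of request.
  tf.profiler.experimental.client.trace(
      service_addr=f'grpc://localhost:{port}',
      logdir=logdir,  # Profiles land under <logdir>/plugins/profile/...
      duration_ms=2000,  # Sample for two seconds, then stop.
  )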

gematria/model/python/main_function_test.py

Lines changed: 67 additions & 0 deletions
@@ -17,6 +17,7 @@
 from os import path
 import os
 import re
+import threading
 from unittest import mock
 
 from absl import flags
@@ -769,6 +770,72 @@ def test_multi_task_flags(self):
     FLAGS.gematria_throughput_source_filter = ['alice', 'bob']
     FLAGS.validate_all_flags()
 
+  @flagsaver.flagsaver
+  def test_train_under_tf_profiler(self):
+    """Tests the profiling of model training using the TF Profiler.
+
+    The test prepares training data and runs the actual training for a small
+    number of epochs under the TF Profiler, then checks that the expected
+    profiles were recorded and stored in the expected directory.
+    """
+    num_epochs = 10
+    max_blocks_in_batch = 15
+    max_instructions_in_batch = 124
+    learning_rate = 0.321
+    randomize_batches = False
+    training_throughput_selection = io_options.ThroughputSelection.RANDOM
+    checkpoint_dir = path.join(self.work_directory.full_path, 'checkpoint')
+    summary_dir = path.join(self.work_directory.full_path, 'summary')
+    tf_profiler_port = 6009
+    use_seq2seq_loss = False  # The default is True.
+
+    model = None
+
+    def MockModel(*args, **kwargs):
+      nonlocal model
+      self.assertEqual(kwargs['learning_rate'], learning_rate)
+      model = TestModel(*args, **kwargs)
+      # Record calls to model.train(), but still call the original method.
+      mock_train = mock.MagicMock(side_effect=model.train)
+      model.train = mock_train
+      return model
+
+    FLAGS.gematria_action = model_options.Action.TRAIN
+    FLAGS.gematria_run_tf_profiler = True
+    FLAGS.gematria_tf_profiler_port = tf_profiler_port
+    FLAGS.gematria_input_file = (self.input_filename,)
+    FLAGS.gematria_checkpoint_dir = checkpoint_dir
+    FLAGS.gematria_summary_dir = summary_dir
+    FLAGS.gematria_training_num_epochs = num_epochs
+    FLAGS.gematria_training_randomize_batches = randomize_batches
+    FLAGS.gematria_max_blocks_in_batch = max_blocks_in_batch
+    FLAGS.gematria_max_instructions_in_batch = max_instructions_in_batch
+    FLAGS.gematria_use_seq2seq_loss = use_seq2seq_loss
+    FLAGS.gematria_learning_rate = learning_rate
+    FLAGS.gematria_training_throughput_selection = training_throughput_selection
+
+    # Set up a thread for the training process running the profiling server.
+    server_thread = threading.Thread(
+        target=main_function.run_gematria_model_from_command_line_flags,
+        args=(MockModel,),
+        kwargs={'dtype': tf.dtypes.float32},
+    )
+    server_thread.start()
+
+    # Try sending a trace request to the TF Profiler.
+    tf.profiler.experimental.client.trace(
+        service_addr=f'grpc://localhost:{tf_profiler_port}',
+        logdir=summary_dir,
+        duration_ms=1000,
+        num_tracing_attempts=4000,  # Keep trying until the server is ready.
+    )
+    server_thread.join()
+
+    # Check that the profile has been written to the expected location.
+    self._assert_file_exists(
+        f'summary/plugins/profile/*/localhost_{tf_profiler_port}.xplane.pb'
+    )
+
 
 if __name__ == '__main__':
   tf.test.main()

gematria/model/python/model_base.py

Lines changed: 11 additions & 8 deletions
@@ -1298,14 +1298,17 @@ def run_one_epoch():
 
     with timer.scoped('ModelBase.train - one batch', num_iterations=num_epochs):
       for epoch_index in range(num_epochs):
-        tf.summary.experimental.set_step(epoch_index)
-        stats = run_one_epoch()
-        logging.info('Training: %s', stats)
-        if not hooks:
-          continue
-        for epochs_every, hook_function in hooks:
-          if (epoch_index + 1) % epochs_every == 0:
-            hook_function()
+        with tf.profiler.experimental.Trace(
+            'train', step_num=epoch_index, _r=1
+        ):
+          tf.summary.experimental.set_step(epoch_index)
+          stats = run_one_epoch()
+          logging.info('Training: %s', stats)
+          if not hooks:
+            continue
+          for epochs_every, hook_function in hooks:
+            if (epoch_index + 1) % epochs_every == 0:
+              hook_function()
     return stats
 
   def _compute_loss(self, schedule: FeedDict) -> loss_utils.LossComputation:
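
The training-loop change wraps each epoch in tf.profiler.experimental.Trace so that captured traces attribute events to individual steps: step_num labels the step, and _r=1 marks the span as a step boundary that TensorBoard's overview page uses to compute per-step time. A minimal standalone sketch of the same pattern (the dummy workload is illustrative, not Gematria code):

import tensorflow as tf

def run_one_step():
  # Stand-in for one epoch of training work.
  return tf.random.normal((256, 256)) @ tf.random.normal((256, 256))

for step in range(3):
  # The Trace context records nothing unless a profiling session is
  # active, so it is safe to leave in the loop unconditionally.
  with tf.profiler.experimental.Trace('train', step_num=step, _r=1):
    run_one_step()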
