From 4f4b50b45d00705f6f833a995dc4cbe8e62a7c70 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Sat, 29 Aug 2020 01:26:55 +0100 Subject: [PATCH] Refactor out shared LSTM/GGNN training loop. github.com/ChrisCummins/ProGraML/issues/69 --- programl/task/dataflow/BUILD | 2 + programl/task/dataflow/dataflow.py | 71 +++++++++++++++++++++++++++- programl/task/dataflow/ggnn.py | 47 ++---------------- programl/task/dataflow/train_lstm.py | 47 ++---------------- 4 files changed, 78 insertions(+), 89 deletions(-) diff --git a/programl/task/dataflow/BUILD b/programl/task/dataflow/BUILD index 8ad5bbb00..732fb9f0c 100644 --- a/programl/task/dataflow/BUILD +++ b/programl/task/dataflow/BUILD @@ -55,6 +55,8 @@ py_library( deps = [ ":graph_loader", "//programl/models:async_batch_builder", + "//programl/models:base_batch_builder", + "//programl/models:model", "//programl/models/ggnn", "//programl/proto:checkpoint_py", "//programl/proto:epoch_py", diff --git a/programl/task/dataflow/dataflow.py b/programl/task/dataflow/dataflow.py index 4551c249f..906d8e6cc 100644 --- a/programl/task/dataflow/dataflow.py +++ b/programl/task/dataflow/dataflow.py @@ -21,9 +21,11 @@ import warnings from typing import Tuple -from labm8.py import app, pbutil +from labm8.py import app, humanize, pbutil from sklearn.exceptions import UndefinedMetricWarning +from programl.models.base_batch_builder import BaseBatchBuilder +from programl.models.model import Model from programl.proto import checkpoint_pb2, epoch_pb2 app.DEFINE_string( @@ -208,3 +210,70 @@ def CreateLoggingDirectories( (log_dir / "checkpoints").mkdir() (log_dir / "graph_loader").mkdir() return log_dir + + +def run_training_loop( + log_dir: pathlib.Path, + epochs, + val_batches: BaseBatchBuilder, + start_epoch_step: int, + model: Model, + val_graph_count: int, +) -> pathlib.Path: + """ + + Args: + log_dir: The logging directory. + epochs: An epoch batch builder. + val_batches: A batch builder for validation. + start_epoch_step: The initial step count. + model: The model to train. + val_graph_count: The number of validation graphs. + + Returns: + The log_dir first argument. + """ + for ( + epoch_step, + (train_graph_count, train_graph_cumsum, train_batches), + ) in enumerate(epochs, start=start_epoch_step): + start_time = time.time() + hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs" + + train_results = model.RunBatches( + epoch_pb2.TRAIN, + train_batches, + log_prefix=f"Train to {hr_graph_cumsum}", + total_graph_count=train_graph_count, + ) + val_results = model.RunBatches( + epoch_pb2.VAL, + val_batches.batches, + log_prefix=f"Val at {hr_graph_cumsum}", + total_graph_count=val_graph_count, + ) + + # Write the epoch to file as an epoch list. This may seem redundant since + # epoch list contains a single item, but it means that we can easily + # concatenate a sequence of these epoch protos to produce a valid epoch + # list using: `cat *.EpochList.pbtxt > epochs.pbtxt` + epoch = epoch_pb2.EpochList( + epoch=[ + epoch_pb2.Epoch( + walltime_seconds=time.time() - start_time, + epoch_num=epoch_step, + train_results=train_results, + val_results=val_results, + ) + ] + ) + print(epoch, end="") + + epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt" + pbutil.ToFile(epoch, epoch_path) + app.Log(1, "Wrote %s", epoch_path) + + checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb" + pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path) + + return log_dir diff --git a/programl/task/dataflow/ggnn.py b/programl/task/dataflow/ggnn.py index f0c39e1bb..25b0db1c1 100644 --- a/programl/task/dataflow/ggnn.py +++ b/programl/task/dataflow/ggnn.py @@ -173,50 +173,9 @@ def TrainDataflowGGNN( ) ) - for ( - epoch_step, - (train_graph_count, train_graph_cumsum, train_batches), - ) in enumerate(epochs, start=start_epoch_step): - start_time = time.time() - hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs" - - train_results = model.RunBatches( - epoch_pb2.TRAIN, - train_batches, - log_prefix=f"Train to {hr_graph_cumsum}", - total_graph_count=train_graph_count, - ) - val_results = model.RunBatches( - epoch_pb2.VAL, - val_batches.batches, - log_prefix=f"Val at {hr_graph_cumsum}", - total_graph_count=val_graph_count, - ) - - # Write the epoch to file as an epoch list. This may seem redundant since - # epoch list contains a single item, but it means that we can easily - # concatenate a sequence of these epoch protos to produce a valid epoch - # list using: `cat *.EpochList.pbtxt > epochs.pbtxt` - epoch = epoch_pb2.EpochList( - epoch=[ - epoch_pb2.Epoch( - walltime_seconds=time.time() - start_time, - epoch_num=epoch_step, - train_results=train_results, - val_results=val_results, - ) - ] - ) - print(epoch, end="") - - epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt" - pbutil.ToFile(epoch, epoch_path) - app.Log(1, "Wrote %s", epoch_path) - - checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb" - pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path) - - return log_dir + return dataflow.run_training_loop( + log_dir, epochs, val_batches, start_epoch_step, model, val_graph_count + ) def TestDataflowGGNN( diff --git a/programl/task/dataflow/train_lstm.py b/programl/task/dataflow/train_lstm.py index 6f3ee3d5b..9108363f4 100644 --- a/programl/task/dataflow/train_lstm.py +++ b/programl/task/dataflow/train_lstm.py @@ -160,50 +160,9 @@ def TrainDataflowLSTM( ) ) - for ( - epoch_step, - (train_graph_count, train_graph_cumsum, train_batches), - ) in enumerate(epochs, start=start_epoch_step): - start_time = time.time() - hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs" - - train_results = model.RunBatches( - epoch_pb2.TRAIN, - train_batches, - log_prefix=f"Train to {hr_graph_cumsum}", - total_graph_count=train_graph_count, - ) - val_results = model.RunBatches( - epoch_pb2.VAL, - val_batches.batches, - log_prefix=f"Val at {hr_graph_cumsum}", - total_graph_count=FLAGS.val_graph_count, - ) - - # Write the epoch to file as an epoch list. This may seem redundant since - # epoch list contains a single item, but it means that we can easily - # concatenate a sequence of these epoch protos to produce a valid epoch - # list using: `cat *.EpochList.pbtxt > epochs.pbtxt` - epoch = epoch_pb2.EpochList( - epoch=[ - epoch_pb2.Epoch( - walltime_seconds=time.time() - start_time, - epoch_num=epoch_step, - train_results=train_results, - val_results=val_results, - ) - ] - ) - print(epoch, end="") - - epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt" - pbutil.ToFile(epoch, epoch_path) - app.Log(1, "Wrote %s", epoch_path) - - checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb" - pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path) - - return log_dir + return dataflow.run_training_loop( + log_dir, epochs, val_batches, start_epoch_step, model, FLAGS.val_graph_count + ) def TestDataflowLSTM(