Skip to content

Commit 9b06bd6

Browse files
committed
Refactor out shared LSTM/GGNN training loop.
github.com//issues/69
1 parent 9c4c442 commit 9b06bd6

File tree

3 files changed

+50
-89
lines changed

3 files changed

+50
-89
lines changed

programl/task/dataflow/dataflow.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import warnings
2222
from typing import Tuple
2323

24-
from labm8.py import app, pbutil
24+
from labm8.py import app, humanize, pbutil
2525
from sklearn.exceptions import UndefinedMetricWarning
2626

2727
from programl.proto import checkpoint_pb2, epoch_pb2
@@ -208,3 +208,50 @@ def CreateLoggingDirectories(
208208
(log_dir / "checkpoints").mkdir()
209209
(log_dir / "graph_loader").mkdir()
210210
return log_dir
211+
212+
213+
def run_training_loop(
    log_dir,
    epochs,
    start_epoch_step,
    model,
    val_batches=None,
    val_graph_count=None,
):
    """Run the shared train/validate loop for the dataflow LSTM and GGNN models.

    For every epoch yielded by `epochs`: train the model on that epoch's
    batches, run a validation pass, write the results as a single-entry
    `EpochList` proto under `log_dir / "epochs"`, and save a model
    checkpoint under `log_dir / "checkpoints"`.

    NOTE(review): this module must have `time` imported and a module-level
    `FLAGS` (presumably `app.FLAGS`) in scope — neither is visible in the
    diffed import block; confirm against the full file.

    Args:
        log_dir: Root logging directory (a pathlib-like path) containing
            `epochs/` and `checkpoints/` subdirectories, as created by
            CreateLoggingDirectories().
        epochs: An iterable of
            (train_graph_count, train_graph_cumsum, train_batches) tuples,
            one per epoch.
        start_epoch_step: Number of the first epoch, used both as the
            `epoch_num` and to name the per-epoch output files.
        model: The model to train. Must expose RunBatches() and
            SaveCheckpoint().
        val_batches: Object with a `.batches` attribute yielding the
            validation batches. The original refactored code read a free
            variable `val_batches` that is undefined in this module (a
            NameError on the first epoch); it is now an explicit parameter.
        val_graph_count: Total number of validation graphs, used for
            progress logging. Defaults to FLAGS.val_graph_count (the
            train_lstm.py behavior); GGNN callers, which previously used a
            local variable, should pass their value explicitly.

    Returns:
        log_dir, unchanged.

    Raises:
        ValueError: If `val_batches` is not provided.
    """
    if val_batches is None:
        # Fail with a clear message instead of the NameError that the
        # original free-variable read produced; existing four-argument
        # callers must be updated to pass their validation batches.
        raise ValueError("run_training_loop() requires val_batches")
    if val_graph_count is None:
        val_graph_count = FLAGS.val_graph_count

    for (
        epoch_step,
        (train_graph_count, train_graph_cumsum, train_batches),
    ) in enumerate(epochs, start=start_epoch_step):
        start_time = time.time()
        hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"

        train_results = model.RunBatches(
            epoch_pb2.TRAIN,
            train_batches,
            log_prefix=f"Train to {hr_graph_cumsum}",
            total_graph_count=train_graph_count,
        )
        val_results = model.RunBatches(
            epoch_pb2.VAL,
            val_batches.batches,
            log_prefix=f"Val at {hr_graph_cumsum}",
            total_graph_count=val_graph_count,
        )

        # Write the epoch to file as an epoch list. This may seem redundant
        # since the epoch list contains a single item, but it means that we
        # can easily concatenate a sequence of these epoch protos to produce
        # a valid epoch list using:
        #   `cat *.EpochList.pbtxt > epochs.pbtxt`
        epoch = epoch_pb2.EpochList(
            epoch=[
                epoch_pb2.Epoch(
                    walltime_seconds=time.time() - start_time,
                    epoch_num=epoch_step,
                    train_results=train_results,
                    val_results=val_results,
                )
            ]
        )
        print(epoch, end="")

        epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
        pbutil.ToFile(epoch, epoch_path)
        app.Log(1, "Wrote %s", epoch_path)

        checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
        pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)

    return log_dir

programl/task/dataflow/ggnn.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -173,50 +173,7 @@ def TrainDataflowGGNN(
173173
)
174174
)
175175

176-
for (
177-
epoch_step,
178-
(train_graph_count, train_graph_cumsum, train_batches),
179-
) in enumerate(epochs, start=start_epoch_step):
180-
start_time = time.time()
181-
hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"
182-
183-
train_results = model.RunBatches(
184-
epoch_pb2.TRAIN,
185-
train_batches,
186-
log_prefix=f"Train to {hr_graph_cumsum}",
187-
total_graph_count=train_graph_count,
188-
)
189-
val_results = model.RunBatches(
190-
epoch_pb2.VAL,
191-
val_batches.batches,
192-
log_prefix=f"Val at {hr_graph_cumsum}",
193-
total_graph_count=val_graph_count,
194-
)
195-
196-
# Write the epoch to file as an epoch list. This may seem redundant since
197-
# epoch list contains a single item, but it means that we can easily
198-
# concatenate a sequence of these epoch protos to produce a valid epoch
199-
# list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
200-
epoch = epoch_pb2.EpochList(
201-
epoch=[
202-
epoch_pb2.Epoch(
203-
walltime_seconds=time.time() - start_time,
204-
epoch_num=epoch_step,
205-
train_results=train_results,
206-
val_results=val_results,
207-
)
208-
]
209-
)
210-
print(epoch, end="")
211-
212-
epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
213-
pbutil.ToFile(epoch, epoch_path)
214-
app.Log(1, "Wrote %s", epoch_path)
215-
216-
checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
217-
pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)
218-
219-
return log_dir
176+
return dataflow.run_training_loop(log_dir, epochs, start_epoch_step, model)
220177

221178

222179
def TestDataflowGGNN(

programl/task/dataflow/train_lstm.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -160,50 +160,7 @@ def TrainDataflowLSTM(
160160
)
161161
)
162162

163-
for (
164-
epoch_step,
165-
(train_graph_count, train_graph_cumsum, train_batches),
166-
) in enumerate(epochs, start=start_epoch_step):
167-
start_time = time.time()
168-
hr_graph_cumsum = f"{humanize.Commas(train_graph_cumsum)} graphs"
169-
170-
train_results = model.RunBatches(
171-
epoch_pb2.TRAIN,
172-
train_batches,
173-
log_prefix=f"Train to {hr_graph_cumsum}",
174-
total_graph_count=train_graph_count,
175-
)
176-
val_results = model.RunBatches(
177-
epoch_pb2.VAL,
178-
val_batches.batches,
179-
log_prefix=f"Val at {hr_graph_cumsum}",
180-
total_graph_count=FLAGS.val_graph_count,
181-
)
182-
183-
# Write the epoch to file as an epoch list. This may seem redundant since
184-
# epoch list contains a single item, but it means that we can easily
185-
# concatenate a sequence of these epoch protos to produce a valid epoch
186-
# list using: `cat *.EpochList.pbtxt > epochs.pbtxt`
187-
epoch = epoch_pb2.EpochList(
188-
epoch=[
189-
epoch_pb2.Epoch(
190-
walltime_seconds=time.time() - start_time,
191-
epoch_num=epoch_step,
192-
train_results=train_results,
193-
val_results=val_results,
194-
)
195-
]
196-
)
197-
print(epoch, end="")
198-
199-
epoch_path = log_dir / "epochs" / f"{epoch_step:03d}.EpochList.pbtxt"
200-
pbutil.ToFile(epoch, epoch_path)
201-
app.Log(1, "Wrote %s", epoch_path)
202-
203-
checkpoint_path = log_dir / "checkpoints" / f"{epoch_step:03d}.Checkpoint.pb"
204-
pbutil.ToFile(model.SaveCheckpoint(), checkpoint_path)
205-
206-
return log_dir
163+
return dataflow.run_training_loop(log_dir, epochs, start_epoch_step, model)
207164

208165

209166
def TestDataflowLSTM(

0 commit comments

Comments
 (0)