Skip to content

Commit 87e51a6

Browse files
committed
ensure reproduciblity & consistent restart on Tensorboard side
1 parent 8465a43 commit 87e51a6

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

main.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,14 @@
7979
help='value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results')
8080
parser.add_argument('--milestones', default=[100,150,200], metavar='N', nargs='*', help='epochs at which learning rate is divided by 2')
8181

82+
parser.add_argument(
83+
"--seed-split",
84+
default=None,
85+
help="Seed the train-val split to enforce reproducibility (consistent restart too)",
86+
)
87+
8288
best_EPE = -1
83-
n_iter = 0
89+
n_iter = int(start_epoch)
8490
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8591

8692

@@ -102,6 +108,9 @@ def main():
102108
if not os.path.exists(save_path):
103109
os.makedirs(save_path)
104110

111+
if args.seed_split:
112+
np.random.seed(int(args.seed_split))
113+
105114
train_writer = SummaryWriter(os.path.join(save_path,'train'))
106115
test_writer = SummaryWriter(os.path.join(save_path,'test'))
107116
output_writers = []
@@ -294,7 +303,7 @@ def validate(val_loader, model, epoch, output_writers):
294303
end = time.time()
295304

296305
if i < len(output_writers): # log first output of first batches
297-
if epoch == 0:
306+
if epoch == args.start_epoch:
298307
mean_values = torch.tensor([0.45,0.432,0.411], dtype=input.dtype).view(3,1,1)
299308
output_writers[i].add_image('GroundTruth', flow2rgb(args.div_flow * target[0], max_value=10), 0)
300309
output_writers[i].add_image('Inputs', (input[0,:3].cpu() + mean_values).clamp(0,1), 0)

0 commit comments

Comments
 (0)