Skip to content

Commit 251088a

Browse files
Merge pull request #110 from yasserben/reproducibility
ensure reproduciblity & consistent restart on Tensorboard side
2 parents 8465a43 + 186e970 commit 251088a

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

main.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535
group.add_argument('--split-value', default=0.8, type=float,
3636
help='test-val split proportion between 0 (only test) and 1 (only train), '
3737
'will be overwritten if a split file is set')
38+
parser.add_argument(
39+
"--split-seed",
40+
type=int,
41+
default=None,
42+
help="Seed the train-val split to enforce reproducibility (consistent restart too)",
43+
)
3844
parser.add_argument('--arch', '-a', metavar='ARCH', default='flownets',
3945
choices=model_names,
4046
help='model architecture, overwritten if pretrained is specified: ' +
@@ -79,8 +85,10 @@
7985
help='value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results')
8086
parser.add_argument('--milestones', default=[100,150,200], metavar='N', nargs='*', help='epochs at which learning rate is divided by 2')
8187

88+
89+
8290
best_EPE = -1
83-
n_iter = 0
91+
n_iter = int(start_epoch)
8492
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8593

8694

@@ -102,6 +110,9 @@ def main():
102110
if not os.path.exists(save_path):
103111
os.makedirs(save_path)
104112

113+
if args.seed_split is not None:
114+
np.random.seed(args.seed_split)
115+
105116
train_writer = SummaryWriter(os.path.join(save_path,'train'))
106117
test_writer = SummaryWriter(os.path.join(save_path,'test'))
107118
output_writers = []
@@ -294,7 +305,7 @@ def validate(val_loader, model, epoch, output_writers):
294305
end = time.time()
295306

296307
if i < len(output_writers): # log first output of first batches
297-
if epoch == 0:
308+
if epoch == args.start_epoch:
298309
mean_values = torch.tensor([0.45,0.432,0.411], dtype=input.dtype).view(3,1,1)
299310
output_writers[i].add_image('GroundTruth', flow2rgb(args.div_flow * target[0], max_value=10), 0)
300311
output_writers[i].add_image('Inputs', (input[0,:3].cpu() + mean_values).clamp(0,1), 0)

0 commit comments

Comments
 (0)