Skip to content

Commit 186e970

Browse files
committed
change requested applied
1 parent 87e51a6 commit 186e970

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

main.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535
group.add_argument('--split-value', default=0.8, type=float,
3636
help='test-val split proportion between 0 (only test) and 1 (only train), '
3737
'will be overwritten if a split file is set')
38+
parser.add_argument(
39+
"--split-seed",
40+
type=int,
41+
default=None,
42+
help="Seed the train-val split to enforce reproducibility (consistent restart too)",
43+
)
3844
parser.add_argument('--arch', '-a', metavar='ARCH', default='flownets',
3945
choices=model_names,
4046
help='model architecture, overwritten if pretrained is specified: ' +
@@ -79,11 +85,7 @@
7985
help='value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results')
8086
parser.add_argument('--milestones', default=[100,150,200], metavar='N', nargs='*', help='epochs at which learning rate is divided by 2')
8187

82-
parser.add_argument(
83-
"--seed-split",
84-
default=None,
85-
help="Seed the train-val split to enforce reproducibility (consistent restart too)",
86-
)
88+
8789

8890
best_EPE = -1
8991
n_iter = int(start_epoch)
@@ -108,8 +110,8 @@ def main():
108110
if not os.path.exists(save_path):
109111
os.makedirs(save_path)
110112

111-
if args.seed_split:
112-
np.random.seed(int(args.seed_split))
113+
if args.seed_split is not None:
114+
np.random.seed(args.seed_split)
113115

114116
train_writer = SummaryWriter(os.path.join(save_path,'train'))
115117
test_writer = SummaryWriter(os.path.join(save_path,'test'))

0 commit comments

Comments
 (0)