3535group .add_argument ('--split-value' , default = 0.8 , type = float ,
3636 help = 'test-val split proportion between 0 (only test) and 1 (only train), '
3737 'will be overwritten if a split file is set' )
38+ parser .add_argument (
39+ "--split-seed" ,
40+ type = int ,
41+ default = None ,
42+ help = "Seed the train-val split to enforce reproducibility (consistent restart too)" ,
43+ )
3844parser .add_argument ('--arch' , '-a' , metavar = 'ARCH' , default = 'flownets' ,
3945 choices = model_names ,
4046 help = 'model architecture, overwritten if pretrained is specified: ' +
7985 help = 'value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results' )
8086parser .add_argument ('--milestones' , default = [100 ,150 ,200 ], metavar = 'N' , nargs = '*' , help = 'epochs at which learning rate is divided by 2' )
8187
88+
89+
8290best_EPE = - 1
83- n_iter = 0
91+ n_iter = int ( start_epoch )
8492device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
8593
8694
@@ -102,6 +110,9 @@ def main():
102110 if not os .path .exists (save_path ):
103111 os .makedirs (save_path )
104112
113+ if args .seed_split is not None :
114+ np .random .seed (args .seed_split )
115+
105116 train_writer = SummaryWriter (os .path .join (save_path ,'train' ))
106117 test_writer = SummaryWriter (os .path .join (save_path ,'test' ))
107118 output_writers = []
@@ -294,7 +305,7 @@ def validate(val_loader, model, epoch, output_writers):
294305 end = time .time ()
295306
296307 if i < len (output_writers ): # log first output of first batches
297- if epoch == 0 :
308+ if epoch == args . start_epoch :
298309 mean_values = torch .tensor ([0.45 ,0.432 ,0.411 ], dtype = input .dtype ).view (3 ,1 ,1 )
299310 output_writers [i ].add_image ('GroundTruth' , flow2rgb (args .div_flow * target [0 ], max_value = 10 ), 0 )
300311 output_writers [i ].add_image ('Inputs' , (input [0 ,:3 ].cpu () + mean_values ).clamp (0 ,1 ), 0 )
0 commit comments