7979 help = 'value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results' )
8080parser .add_argument ('--milestones' , default = [100 ,150 ,200 ], metavar = 'N' , nargs = '*' , help = 'epochs at which learning rate is divided by 2' )
8181
82+ parser .add_argument (
83+ "--seed-split" ,
84+ default = None ,
85+ help = "Seed the train-val split to enforce reproducibility (consistent restart too)" ,
86+ )
87+
8288best_EPE = - 1
83- n_iter = 0
89+ n_iter = int ( start_epoch )
8490device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
8591
8692
@@ -102,6 +108,9 @@ def main():
102108 if not os .path .exists (save_path ):
103109 os .makedirs (save_path )
104110
111+ if args .seed_split :
112+ np .random .seed (int (args .seed_split ))
113+
105114 train_writer = SummaryWriter (os .path .join (save_path ,'train' ))
106115 test_writer = SummaryWriter (os .path .join (save_path ,'test' ))
107116 output_writers = []
@@ -294,7 +303,7 @@ def validate(val_loader, model, epoch, output_writers):
294303 end = time .time ()
295304
296305 if i < len (output_writers ): # log first output of first batches
297- if epoch == 0 :
306+ if epoch == args . start_epoch :
298307 mean_values = torch .tensor ([0.45 ,0.432 ,0.411 ], dtype = input .dtype ).view (3 ,1 ,1 )
299308 output_writers [i ].add_image ('GroundTruth' , flow2rgb (args .div_flow * target [0 ], max_value = 10 ), 0 )
300309 output_writers [i ].add_image ('Inputs' , (input [0 ,:3 ].cpu () + mean_values ).clamp (0 ,1 ), 0 )
0 commit comments