 best_result = Result()
 best_result.set_to_worst()
 
+def create_data_loaders(args):
+    # Data loading code
+    print("=> creating data loaders ...")
+    traindir = os.path.join('data', args.data, 'train')
+    valdir = os.path.join('data', args.data, 'val')
+    train_loader = None
+    val_loader = None
+
+    # sparsifier is a class for generating random sparse depth input from the ground truth
+    sparsifier = None
+    max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
+    if args.sparsifier == UniformSampling.name:
+        sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
+    elif args.sparsifier == SimulatedStereo.name:
+        sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)
+
+    if args.data == 'nyudepthv2':
+        from dataloaders.nyu_dataloader import NYUDataset
+        if not args.evaluate:
+            train_dataset = NYUDataset(traindir, type='train',
+                modality=args.modality, sparsifier=sparsifier)
+        val_dataset = NYUDataset(valdir, type='val',
+            modality=args.modality, sparsifier=sparsifier)
+
+    elif args.data == 'kitti':
+        from dataloaders.kitti_dataloader import KITTIDataset
+        if not args.evaluate:
+            train_dataset = KITTIDataset(traindir, type='train',
+                modality=args.modality, sparsifier=sparsifier)
+        val_dataset = KITTIDataset(valdir, type='val',
+            modality=args.modality, sparsifier=sparsifier)
+
+    else:
+        raise RuntimeError('Dataset not found.' +
+                           ' The dataset must be either nyudepthv2 or kitti.')
+
+    # set batch size to be 1 for validation
+    val_loader = torch.utils.data.DataLoader(val_dataset,
+        batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
+
+    # construct the train loader here, so that evaluation-only runs can skip it
+    if not args.evaluate:
+        train_loader = torch.utils.data.DataLoader(
+            train_dataset, batch_size=args.batch_size, shuffle=True,
+            num_workers=args.workers, pin_memory=True, sampler=None,
+            worker_init_fn=lambda work_id: np.random.seed(work_id))
+            # worker_init_fn ensures a different sampling pattern in each data-loading worker
+
+    print("=> data loaders created.")
+    return train_loader, val_loader
+
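A sparsifier only needs to produce a boolean mask selecting which ground-truth depth pixels are kept as sparse input. Below is a minimal sketch of a UniformSampling-style sparsifier; the class name, `name` attribute value, and `dense_to_sparse` method are assumptions inferred from the constructor calls above, not the repository's actual implementation.

import numpy as np

class UniformSamplingSketch:
    # hypothetical stand-in for UniformSampling; names are assumptions
    name = 'uniform'

    def __init__(self, num_samples, max_depth=np.inf):
        self.num_samples = num_samples
        self.max_depth = max_depth

    def dense_to_sparse(self, rgb, depth):
        # valid ground-truth pixels: positive depth within max_depth
        valid = (depth > 0) & (depth <= self.max_depth)
        n_valid = np.count_nonzero(valid)
        if n_valid == 0:
            return np.zeros_like(valid)
        # Bernoulli mask keeping roughly num_samples of the valid pixels
        prob = float(self.num_samples) / n_valid
        return valid & (np.random.uniform(size=depth.shape) < prob)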
 def main():
     global args, best_result, output_directory, train_csv, test_csv
 
@@ -33,12 +84,16 @@ def main():
         "=> no best model found at '{}'".format(args.evaluate)
         print("=> loading best model '{}'".format(args.evaluate))
         checkpoint = torch.load(args.evaluate)
+        output_directory = os.path.dirname(args.evaluate)
         args = checkpoint['args']
-        args.evaluate = True
         start_epoch = checkpoint['epoch'] + 1
         best_result = checkpoint['best_result']
         model = checkpoint['model']
         print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
+        _, val_loader = create_data_loaders(args)
+        args.evaluate = True
+        validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
+        return
 
     # optionally resume from a checkpoint
     elif args.resume:
@@ -51,93 +106,35 @@ def main():
         best_result = checkpoint['best_result']
         model = checkpoint['model']
         optimizer = checkpoint['optimizer']
-        output_directory, _ = os.path.split(args.resume)
+        output_directory = os.path.dirname(os.path.abspath(args.resume))
         print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
+        train_loader, val_loader = create_data_loaders(args)
+        args.resume = True
 
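Both the evaluate and resume branches treat a checkpoint as a plain dictionary produced by torch.save. A minimal sketch of the matching save side, consistent with the keys read above (the helper name and filename are illustrative, not the file's actual save routine):

import os
import torch

def save_checkpoint_sketch(output_directory, epoch, args, model, optimizer, best_result):
    # keys mirror those read by the evaluate/resume branches
    state = {
        'epoch': epoch,
        'args': args,
        'model': model,          # the full module is stored, not just a state_dict
        'optimizer': optimizer,
        'best_result': best_result,
    }
    torch.save(state, os.path.join(output_directory, 'checkpoint-{}.pth.tar'.format(epoch)))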
     # create new model
     else:
-        # define model
+        train_loader, val_loader = create_data_loaders(args)
         print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))
         in_channels = len(args.modality)
         if args.arch == 'resnet50':
-            model = ResNet(layers=50, decoder=args.decoder, output_size=train_dataset.output_size,
+            model = ResNet(layers=50, decoder=args.decoder, output_size=train_loader.dataset.output_size,
                 in_channels=in_channels, pretrained=args.pretrained)
         elif args.arch == 'resnet18':
-            model = ResNet(layers=18, decoder=args.decoder, output_size=train_dataset.output_size,
+            model = ResNet(layers=18, decoder=args.decoder, output_size=train_loader.dataset.output_size,
                 in_channels=in_channels, pretrained=args.pretrained)
         print("=> model created.")
-
         optimizer = torch.optim.SGD(model.parameters(), args.lr, \
             momentum=args.momentum, weight_decay=args.weight_decay)
 
-        # create new csv files with only header
-        with open(train_csv, 'w') as csvfile:
-            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
-            writer.writeheader()
-        with open(test_csv, 'w') as csvfile:
-            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
-            writer.writeheader()
-
-    # model = torch.nn.DataParallel(model).cuda() # for multi-gpu training
-    model = model.cuda()
-    # print(model)
-    print("=> model transferred to GPU.")
+        # model = torch.nn.DataParallel(model).cuda() # for multi-gpu training
+        model = model.cuda()
 
     # define loss function (criterion) and optimizer
     if args.criterion == 'l2':
         criterion = criteria.MaskedMSELoss().cuda()
     elif args.criterion == 'l1':
         criterion = criteria.MaskedL1Loss().cuda()
 
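Both criteria come from the repository's criteria module; the essential idea is that the loss is computed only over pixels that have a ground-truth depth reading. A minimal sketch of a masked L2 loss, assuming missing depth is encoded as 0 (see criteria.MaskedMSELoss for the actual implementation):

import torch.nn as nn

class MaskedMSELossSketch(nn.Module):
    def forward(self, pred, target):
        valid = target > 0            # pixels with a ground-truth reading
        diff = pred[valid] - target[valid]
        return (diff ** 2).mean()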
-    # sparsifier is a class for generating random sparse depth input from the ground truth
-    sparsifier = None
-    max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
-    if args.sparsifier == UniformSampling.name:
-        sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
-    elif args.sparsifier == SimulatedStereo.name:
-        sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)
-
-    # Data loading code
-    print("=> creating data loaders ...")
-    traindir = os.path.join('data', args.data, 'train')
-    valdir = os.path.join('data', args.data, 'val')
-
-    if args.data == 'nyudepthv2':
-        from dataloaders.nyu_dataloader import NYUDataset
-        if not args.evaluate:
-            train_dataset = NYUDataset(traindir, type='train',
-                modality=args.modality, sparsifier=sparsifier)
-        val_dataset = NYUDataset(valdir, type='val',
-            modality=args.modality, sparsifier=sparsifier)
-
-    elif args.data == 'kitti':
-        from dataloaders.kitti_dataloader import KITTIDataset
-        if not args.evaluate:
-            train_dataset = KITTIDataset(traindir, type='train',
-                modality=args.modality, sparsifier=sparsifier)
-        val_dataset = KITTIDataset(valdir, type='val',
-            modality=args.modality, sparsifier=sparsifier)
-
-    else:
-        raise RuntimeError('Dataset not found.' +
-                           'The dataset must be either of nyudepthv2 or kitti.')
-
-    # set batch size to be 1 for validation
-    val_loader = torch.utils.data.DataLoader(val_dataset,
-        batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
-    print("=> data loaders created.")
-
-    if args.evaluate:
-        validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
-        return
-
-    # put construction of train loader here, for those who are interested in testing only
-    train_loader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.batch_size, shuffle=True,
-        num_workers=args.workers, pin_memory=True, sampler=None,
-        worker_init_fn=lambda work_id: np.random.seed(work_id))
-        # worker_init_fn ensures different sampling patterns for each data loading thread
-
     # create results folder, if it does not already exist
     output_directory = utils.get_output_directory(args)
     if not os.path.exists(output_directory):
@@ -146,6 +143,15 @@ def main():
     test_csv = os.path.join(output_directory, 'test.csv')
     best_txt = os.path.join(output_directory, 'best.txt')
 
+    # create new csv files with only header
+    if not args.resume:
+        with open(train_csv, 'w') as csvfile:
+            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+            writer.writeheader()
+        with open(test_csv, 'w') as csvfile:
+            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+            writer.writeheader()
+
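With the headers written once per fresh run (a resumed run keeps appending to the existing files), each epoch can then add one row through the same DictWriter interface. A hypothetical sketch of the append side; `fieldnames` is defined elsewhere in this file, and the row dict must use exactly those keys:

import csv

def append_row_sketch(csv_path, fieldnames, row):
    # append a single epoch's metrics under the existing header
    with open(csv_path, 'a') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writerow(row)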
     for epoch in range(start_epoch, args.epochs):
         utils.adjust_learning_rate(optimizer, epoch, args.lr)
         train(train_loader, model, criterion, optimizer, epoch) # train for one epoch
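utils.adjust_learning_rate lives in the repository's utils module. A plausible sketch of such a helper, where the decay factor and step size are assumptions rather than the repo's actual schedule:

def adjust_learning_rate_sketch(optimizer, epoch, init_lr):
    # hypothetical step schedule: multiply the lr by 0.1 every 5 epochs
    lr = init_lr * (0.1 ** (epoch // 5))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr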