@@ -108,14 +108,15 @@ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
     def load_networks(self, which_epoch):
         for name in self.model_names:
             if isinstance(name, str):
-                save_filename = '%s_net_%s.pth' % (which_epoch, name)
-                save_path = os.path.join(self.save_dir, save_filename)
+                load_filename = '%s_net_%s.pth' % (which_epoch, name)
+                load_path = os.path.join(self.save_dir, load_filename)
                 net = getattr(self, 'net' + name)
                 if isinstance(net, torch.nn.DataParallel):
                     net = net.module
+                print('loading the model from %s' % load_path)
                 # if you are using PyTorch newer than 0.4 (e.g., built from
                 # GitHub source), you can remove str() on self.device
-                state_dict = torch.load(save_path, map_location=str(self.device))
+                state_dict = torch.load(load_path, map_location=str(self.device))
                 # patch InstanceNorm checkpoints prior to 0.4
                 for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
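
For reference, the loading pattern this hunk touches (device-aware torch.load plus DataParallel unwrapping) can be exercised standalone. A minimal sketch, assuming a toy Linear net and a throwaway checkpoint path; all names below are illustrative and not taken from this file:

import torch

# map_location keeps CPU-only machines from failing on CUDA-saved
# checkpoints; unwrapping DataParallel exposes the underlying module
# whose state_dict keys match what was saved.
device = torch.device('cpu')
net = torch.nn.Linear(4, 2)
torch.save(net.state_dict(), 'demo_net_G.pth')  # pretend checkpoint
state_dict = torch.load('demo_net_G.pth', map_location=str(device))
if isinstance(net, torch.nn.DataParallel):
    net = net.module
net.load_state_dict(state_dict)
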
@@ -134,3 +135,12 @@ def print_networks(self, verbose):
                 print(net)
                 print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
                 print('-----------------------------------------------')
+
+    # set requires_grad=False to avoid unnecessary computation
+    def set_requires_grad(self, nets, requires_grad=False):
+        if not isinstance(nets, list):
+            nets = [nets]
+        for net in nets:
+            if net is not None:
+                for param in net.parameters():
+                    param.requires_grad = requires_grad
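
The new set_requires_grad helper is typically used to freeze one network while another is being optimized (e.g. freezing a discriminator during a generator update), so no gradients are computed for the frozen parameters. A minimal standalone sketch; the two Linear "networks" and tensor shapes are invented for illustration:

import torch
import torch.nn as nn

# Free-function version of the helper added above, for a self-contained demo.
def set_requires_grad(nets, requires_grad=False):
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

netG = nn.Linear(8, 8)  # stand-in "generator"
netD = nn.Linear(8, 1)  # stand-in "discriminator"

# Freeze D while backpropagating through it to G: gradients still flow
# to netG's parameters, but none are accumulated for netD's.
set_requires_grad(netD, False)
out = netD(netG(torch.randn(4, 8))).sum()
out.backward()
assert all(p.grad is None for p in netD.parameters())
assert all(p.grad is not None for p in netG.parameters())
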