44from torch import autograd
55from model .networks import Generator , LocalDis , GlobalDis
66
7+
78from utils .tools import get_model_list , local_patch , spatial_discounting_mask
89from utils .logger import get_logger
910
10-
1111logger = get_logger ()
1212
13+
1314class Trainer (nn .Module ):
1415 def __init__ (self , config ):
1516 super (Trainer , self ).__init__ ()
@@ -33,6 +34,7 @@ def __init__(self, config):
3334
3435 def forward (self , x , bboxes , masks , ground_truth , compute_loss_g = False ):
3536 self .train ()
37+ l1_loss = nn .L1Loss ()
3638 losses = {}
3739
3840 x1 , x2 , offset_flow = self .netG (x , masks )
@@ -42,35 +44,37 @@ def forward(self, x, bboxes, masks, ground_truth, compute_loss_g=False):
4244 local_patch_x1_inpaint = local_patch (x1_inpaint , bboxes )
4345 local_patch_x2_inpaint = local_patch (x2_inpaint , bboxes )
4446
45- ## D part
47+ # D part
4648 # wgan d loss
47- local_patch_real_pred , local_patch_fake_pred = \
48- self .dis_forward ( self . localD , local_patch_gt , local_patch_x2_inpaint .detach ())
49- global_real_pred , global_fake_pred = \
50- self .dis_forward ( self . globalD , ground_truth , x2_inpaint .detach ())
51- losses ['wgan_d' ] = torch .mean (local_patch_fake_pred - local_patch_real_pred ) \
52- + torch .mean (global_fake_pred - global_real_pred ) * self .config ['global_wgan_loss_alpha' ]
49+ local_patch_real_pred , local_patch_fake_pred = self . dis_forward (
50+ self .localD , local_patch_gt , local_patch_x2_inpaint .detach ())
51+ global_real_pred , global_fake_pred = self . dis_forward (
52+ self .globalD , ground_truth , x2_inpaint .detach ())
53+ losses ['wgan_d' ] = torch .mean (local_patch_fake_pred - local_patch_real_pred ) + \
54+ torch .mean (global_fake_pred - global_real_pred ) * self .config ['global_wgan_loss_alpha' ]
5355 # gradients penalty loss
54- local_penalty = self .calc_gradient_penalty (self .localD , local_patch_gt , local_patch_x2_inpaint .detach ())
56+ local_penalty = self .calc_gradient_penalty (
57+ self .localD , local_patch_gt , local_patch_x2_inpaint .detach ())
5558 global_penalty = self .calc_gradient_penalty (self .globalD , ground_truth , x2_inpaint .detach ())
5659 losses ['wgan_gp' ] = local_penalty + global_penalty
5760
58- ## G part
61+ # G part
5962 if compute_loss_g :
6063 sd_mask = spatial_discounting_mask (self .config )
61- losses ['l1' ] = nn .L1Loss ()(local_patch_x1_inpaint * sd_mask , local_patch_gt * sd_mask ) \
62- * self .config ['coarse_l1_alpha' ] \
63- + nn .L1Loss ()(local_patch_x2_inpaint * sd_mask , local_patch_gt * sd_mask )
64- losses ['ae' ] = nn .L1Loss ()(x1 * (1. - masks ), ground_truth * (1. - masks )) \
65- * self .config ['coarse_l1_alpha' ] \
66- + nn .L1Loss ()(x2 * (1. - masks ), ground_truth * (1. - masks ))
67- # wgan g loss
68- local_patch_real_pred , local_patch_fake_pred = \
69- self .dis_forward (self .localD , local_patch_gt , local_patch_x2_inpaint )
70- global_real_pred , global_fake_pred = self .dis_forward (self .globalD , ground_truth , x2_inpaint )
64+ losses ['l1' ] = l1_loss (local_patch_x1_inpaint * sd_mask , local_patch_gt * sd_mask ) * \
65+ self .config ['coarse_l1_alpha' ] + \
66+ l1_loss (local_patch_x2_inpaint * sd_mask , local_patch_gt * sd_mask )
67+ losses ['ae' ] = l1_loss (x1 * (1. - masks ), ground_truth * (1. - masks )) * \
68+ self .config ['coarse_l1_alpha' ] + \
69+ l1_loss (x2 * (1. - masks ), ground_truth * (1. - masks ))
7170
72- losses ['wgan_g' ] = - torch .mean (local_patch_fake_pred ) \
73- - torch .mean (global_fake_pred ) * self .config ['global_wgan_loss_alpha' ]
71+ # wgan g loss
72+ local_patch_real_pred , local_patch_fake_pred = self .dis_forward (
73+ self .localD , local_patch_gt , local_patch_x2_inpaint )
74+ global_real_pred , global_fake_pred = self .dis_forward (
75+ self .globalD , ground_truth , x2_inpaint )
76+ losses ['wgan_g' ] = - torch .mean (local_patch_fake_pred ) - \
77+ torch .mean (global_fake_pred ) * self .config ['global_wgan_loss_alpha' ]
7478
7579 return losses , x2_inpaint , offset_flow
7680
@@ -85,26 +89,26 @@ def dis_forward(self, netD, ground_truth, x_inpaint):
8589
8690 # Calculate gradient penalty
8791 def calc_gradient_penalty (self , netD , real_data , fake_data ):
88- batch_size , channel , height , width = real_data .size ()
89- alpha = torch .rand (batch_size , 1 )
90- alpha = alpha .expand (batch_size , int (real_data .nelement () // batch_size )).contiguous () \
91- .view (batch_size , channel , height , width )
92+ batch_size = real_data .size (0 )
93+ alpha = torch .rand (batch_size , 1 , 1 , 1 )
94+ alpha = alpha .expand_as (real_data )
9295 if self .use_cuda :
9396 alpha = alpha .cuda ()
9497
95- interpolates = alpha * real_data + (( 1 - alpha ) * fake_data )
96- interpolates = autograd . Variable ( interpolates , requires_grad = True )
98+ interpolates = alpha * real_data + (1 - alpha ) * fake_data
99+ interpolates = interpolates . requires_grad_ (). clone ( )
97100
98101 disc_interpolates = netD (interpolates )
99-
100102 grad_outputs = torch .ones (disc_interpolates .size ())
103+
101104 if self .use_cuda :
102105 grad_outputs = grad_outputs .cuda ()
106+
103107 gradients = autograd .grad (outputs = disc_interpolates , inputs = interpolates ,
104- grad_outputs = grad_outputs ,
105- create_graph = True , retain_graph = True , only_inputs = True )[0 ]
106- gradients = gradients .view (gradients .size (0 ), - 1 )
108+ grad_outputs = grad_outputs , create_graph = True ,
109+ retain_graph = True , only_inputs = True )[0 ]
107110
111+ gradients = gradients .view (batch_size , - 1 )
108112 gradient_penalty = ((gradients .norm (2 , dim = 1 ) - 1 ) ** 2 ).mean ()
109113
110114 return gradient_penalty
@@ -123,8 +127,10 @@ def save_model(self, checkpoint_dir, iteration):
123127 dis_name = os .path .join (checkpoint_dir , 'dis_%08d.pt' % iteration )
124128 opt_name = os .path .join (checkpoint_dir , 'optimizer.pt' )
125129 torch .save (self .netG .state_dict (), gen_name )
126- torch .save ({'localD' : self .localD .state_dict (), 'globalD' : self .globalD .state_dict ()}, dis_name )
127- torch .save ({'gen' : self .optimizer_g .state_dict (), 'dis' : self .optimizer_d .state_dict ()}, opt_name )
130+ torch .save ({'localD' : self .localD .state_dict (),
131+ 'globalD' : self .globalD .state_dict ()}, dis_name )
132+ torch .save ({'gen' : self .optimizer_g .state_dict (),
133+ 'dis' : self .optimizer_d .state_dict ()}, opt_name )
128134
129135 def resume (self , checkpoint_dir , iteration = 0 , test = False ):
130136 # Load generators
0 commit comments