Commit 800e177

Merge pull request #24 from DAA233/fix-gradient-penalty-bug
fix gradient penalty bug for PyTorch 1.2+ and update the trainer code
2 parents: 58bd3dc + 4aa6bed

trainer.py

Lines changed: 40 additions & 34 deletions
@@ -4,12 +4,13 @@
 from torch import autograd
 from model.networks import Generator, LocalDis, GlobalDis
 
+
 from utils.tools import get_model_list, local_patch, spatial_discounting_mask
 from utils.logger import get_logger
 
-
 logger = get_logger()
 
+
 class Trainer(nn.Module):
     def __init__(self, config):
         super(Trainer, self).__init__()
@@ -33,6 +34,7 @@ def __init__(self, config):
 
     def forward(self, x, bboxes, masks, ground_truth, compute_loss_g=False):
         self.train()
+        l1_loss = nn.L1Loss()
         losses = {}
 
         x1, x2, offset_flow = self.netG(x, masks)
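Note on the hoisted l1_loss = nn.L1Loss(): it replaces the repeated nn.L1Loss()(...) constructions in the generator-loss hunk below. nn.L1Loss is stateless (default reduction 'mean'), so reusing one instance is behaviorally identical; a quick standalone check (illustrative, not from the repo):

    import torch
    from torch import nn

    a, b = torch.randn(2, 3), torch.randn(2, 3)
    l1_loss = nn.L1Loss()
    # One reusable instance vs. a fresh module per call: same value,
    # and both equal the mean absolute error under 'mean' reduction.
    assert torch.allclose(l1_loss(a, b), nn.L1Loss()(a, b))
    assert torch.allclose(l1_loss(a, b), (a - b).abs().mean())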
@@ -42,35 +44,37 @@ def forward(self, x, bboxes, masks, ground_truth, compute_loss_g=False):
         local_patch_x1_inpaint = local_patch(x1_inpaint, bboxes)
         local_patch_x2_inpaint = local_patch(x2_inpaint, bboxes)
 
-        ## D part
+        # D part
         # wgan d loss
-        local_patch_real_pred, local_patch_fake_pred = \
-            self.dis_forward(self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
-        global_real_pred, global_fake_pred = \
-            self.dis_forward(self.globalD, ground_truth, x2_inpaint.detach())
-        losses['wgan_d'] = torch.mean(local_patch_fake_pred - local_patch_real_pred) \
-            + torch.mean(global_fake_pred - global_real_pred) * self.config['global_wgan_loss_alpha']
+        local_patch_real_pred, local_patch_fake_pred = self.dis_forward(
+            self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
+        global_real_pred, global_fake_pred = self.dis_forward(
+            self.globalD, ground_truth, x2_inpaint.detach())
+        losses['wgan_d'] = torch.mean(local_patch_fake_pred - local_patch_real_pred) + \
+            torch.mean(global_fake_pred - global_real_pred) * self.config['global_wgan_loss_alpha']
         # gradients penalty loss
-        local_penalty = self.calc_gradient_penalty(self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
+        local_penalty = self.calc_gradient_penalty(
+            self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
         global_penalty = self.calc_gradient_penalty(self.globalD, ground_truth, x2_inpaint.detach())
         losses['wgan_gp'] = local_penalty + global_penalty
 
-        ## G part
+        # G part
         if compute_loss_g:
             sd_mask = spatial_discounting_mask(self.config)
-            losses['l1'] = nn.L1Loss()(local_patch_x1_inpaint * sd_mask, local_patch_gt * sd_mask) \
-                * self.config['coarse_l1_alpha'] \
-                + nn.L1Loss()(local_patch_x2_inpaint * sd_mask, local_patch_gt * sd_mask)
-            losses['ae'] = nn.L1Loss()(x1 * (1. - masks), ground_truth * (1. - masks)) \
-                * self.config['coarse_l1_alpha'] \
-                + nn.L1Loss()(x2 * (1. - masks), ground_truth * (1. - masks))
-            # wgan g loss
-            local_patch_real_pred, local_patch_fake_pred = \
-                self.dis_forward(self.localD, local_patch_gt, local_patch_x2_inpaint)
-            global_real_pred, global_fake_pred = self.dis_forward(self.globalD, ground_truth, x2_inpaint)
+            losses['l1'] = l1_loss(local_patch_x1_inpaint * sd_mask, local_patch_gt * sd_mask) * \
+                self.config['coarse_l1_alpha'] + \
+                l1_loss(local_patch_x2_inpaint * sd_mask, local_patch_gt * sd_mask)
+            losses['ae'] = l1_loss(x1 * (1. - masks), ground_truth * (1. - masks)) * \
+                self.config['coarse_l1_alpha'] + \
+                l1_loss(x2 * (1. - masks), ground_truth * (1. - masks))
 
-            losses['wgan_g'] = - torch.mean(local_patch_fake_pred) \
-                - torch.mean(global_fake_pred) * self.config['global_wgan_loss_alpha']
+            # wgan g loss
+            local_patch_real_pred, local_patch_fake_pred = self.dis_forward(
+                self.localD, local_patch_gt, local_patch_x2_inpaint)
+            global_real_pred, global_fake_pred = self.dis_forward(
+                self.globalD, ground_truth, x2_inpaint)
+            losses['wgan_g'] = - torch.mean(local_patch_fake_pred) - \
+                torch.mean(global_fake_pred) * self.config['global_wgan_loss_alpha']
 
         return losses, x2_inpaint, offset_flow
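Apart from the re-wrapped lines and the reuse of l1_loss, this hunk leaves the loss values unchanged: the discriminators still see detached inpainting results, and the WGAN terms keep their standard form. A minimal sketch of that arithmetic with stand-in prediction tensors (names and shapes are illustrative only, not from the repo):

    import torch

    batch = 4
    local_real_pred, local_fake_pred = torch.randn(batch), torch.randn(batch)
    global_real_pred, global_fake_pred = torch.randn(batch), torch.randn(batch)
    alpha = 1.0  # stands in for config['global_wgan_loss_alpha']

    # Critic loss: fake minus real, so minimizing it pushes real scores
    # up and fake scores down.
    wgan_d = torch.mean(local_fake_pred - local_real_pred) + \
        torch.mean(global_fake_pred - global_real_pred) * alpha
    # Generator loss: raise the critic's scores on fakes.
    wgan_g = -torch.mean(local_fake_pred) - torch.mean(global_fake_pred) * alpha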

@@ -85,26 +89,26 @@ def dis_forward(self, netD, ground_truth, x_inpaint):
 
     # Calculate gradient penalty
    def calc_gradient_penalty(self, netD, real_data, fake_data):
-        batch_size, channel, height, width = real_data.size()
-        alpha = torch.rand(batch_size, 1)
-        alpha = alpha.expand(batch_size, int(real_data.nelement() // batch_size)).contiguous() \
-            .view(batch_size, channel, height, width)
+        batch_size = real_data.size(0)
+        alpha = torch.rand(batch_size, 1, 1, 1)
+        alpha = alpha.expand_as(real_data)
         if self.use_cuda:
             alpha = alpha.cuda()
 
-        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
-        interpolates = autograd.Variable(interpolates, requires_grad=True)
+        interpolates = alpha * real_data + (1 - alpha) * fake_data
+        interpolates = interpolates.requires_grad_().clone()
 
         disc_interpolates = netD(interpolates)
-
         grad_outputs = torch.ones(disc_interpolates.size())
+
         if self.use_cuda:
             grad_outputs = grad_outputs.cuda()
+
         gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
-                                  grad_outputs=grad_outputs,
-                                  create_graph=True, retain_graph=True, only_inputs=True)[0]
-        gradients = gradients.view(gradients.size(0), -1)
+                                  grad_outputs=grad_outputs, create_graph=True,
+                                  retain_graph=True, only_inputs=True)[0]
 
+        gradients = gradients.view(batch_size, -1)
         gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
 
         return gradient_penalty
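This hunk is the fix the commit message names. The old code assembled alpha through an expand(...).contiguous().view(...) chain and re-wrapped interpolates in autograd.Variable(..., requires_grad=True), a wrapper deprecated since PyTorch 0.4; the new code draws one coefficient per sample as an (N, 1, 1, 1) tensor, broadcasts it with expand_as, and marks the interpolates differentiable with requires_grad_(). A standalone sketch of the fixed penalty, assuming a critic netD mapping NCHW tensors to scores and detached inputs (a reconstruction for illustration, not the repo's exact method):

    import torch
    from torch import autograd

    def gradient_penalty(netD, real_data, fake_data):
        # WGAN-GP: penalize the critic's gradient norm at random points
        # on lines between real and fake samples. Both inputs are
        # assumed not to require grad (fake_data passed detached).
        batch_size = real_data.size(0)
        # One mixing coefficient per sample, broadcast over C, H, W.
        alpha = torch.rand(batch_size, 1, 1, 1).expand_as(real_data)

        interpolates = alpha * real_data + (1 - alpha) * fake_data
        # requires_grad_() replaces the deprecated Variable wrapper.
        interpolates = interpolates.requires_grad_().clone()

        disc_interpolates = netD(interpolates)
        grad_outputs = torch.ones_like(disc_interpolates)
        gradients = autograd.grad(outputs=disc_interpolates,
                                  inputs=interpolates,
                                  grad_outputs=grad_outputs,
                                  create_graph=True, retain_graph=True,
                                  only_inputs=True)[0]

        gradients = gradients.view(batch_size, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

Called as in the trainer, e.g. gradient_penalty(critic, real_batch, fake_batch.detach()); create_graph=True is what allows the penalty itself to be backpropagated when the critic loss is optimized.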
@@ -123,8 +127,10 @@ def save_model(self, checkpoint_dir, iteration):
         dis_name = os.path.join(checkpoint_dir, 'dis_%08d.pt' % iteration)
         opt_name = os.path.join(checkpoint_dir, 'optimizer.pt')
         torch.save(self.netG.state_dict(), gen_name)
-        torch.save({'localD': self.localD.state_dict(), 'globalD': self.globalD.state_dict()}, dis_name)
-        torch.save({'gen': self.optimizer_g.state_dict(), 'dis': self.optimizer_d.state_dict()}, opt_name)
+        torch.save({'localD': self.localD.state_dict(),
+                    'globalD': self.globalD.state_dict()}, dis_name)
+        torch.save({'gen': self.optimizer_g.state_dict(),
+                    'dis': self.optimizer_d.state_dict()}, opt_name)
 
     def resume(self, checkpoint_dir, iteration=0, test=False):
         # Load generators
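The save_model change is cosmetic line wrapping; both discriminators still share one checkpoint file as a nested dict. For completeness, the matching load side looks roughly like this (hypothetical helper; the actual resume body is outside this diff):

    import torch

    def load_discriminators(local_d, global_d, dis_path):
        # dis_path points at the dict written by save_model() above,
        # one state_dict per discriminator.
        state = torch.load(dis_path)
        local_d.load_state_dict(state['localD'])
        global_d.load_state_dict(state['globalD'])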
