diff --git a/models/networks.py b/models/networks.py index c49b9b7bda0..fcc2b6bc567 100644 --- a/models/networks.py +++ b/models/networks.py @@ -281,6 +281,7 @@ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', const lambda_gp (float) -- weight for this loss Returns the gradient penalty loss + NOTE: Strongly advised not to use batch/instance norm with the Discriminator(or Critic) if using gradient penalty! """ if lambda_gp > 0.0: if type == 'real': # either use real images, fake images, or a linear interpolation of two.