While reading through the source code I noticed a couple of small issues.
1. The wgan_train.py source still applies a sigmoid followed by a cross-entropy loss, but WGAN should use the discriminator's raw output logits directly in the loss (a contrast sketch follows the corrected code below):
```python
def d_loss_fn(generator, discriminator, batch_z, real_image):
    fake_image = generator(batch_z, training=True)
    d_fake_score = discriminator(fake_image, training=True)
    d_real_score = discriminator(real_image, training=True)
    # Wasserstein critic loss: E[D(fake)] - E[D(real)]
    loss = tf.reduce_mean(d_fake_score - d_real_score)
    # lambda = 10
    gp = gradient_penalty(discriminator, real_image, fake_image) * 10.
    loss = loss + gp
    return loss, gp


def g_loss_fn(generator, discriminator, batch_z):
    fake_image = generator(batch_z, training=True)
    d_fake_logits = discriminator(fake_image, training=True)
    # loss = celoss_ones(d_fake_logits)
    # generator maximizes the critic score, i.e. minimizes -E[D(fake)]
    loss = -tf.reduce_mean(d_fake_logits)
    return loss
```
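For contrast, here is a minimal side-by-side sketch of the two discriminator loss formulations. The cross-entropy version is only my hypothetical reconstruction of a standard (non-Wasserstein) GAN loss, not the repo's exact code; the Wasserstein version simply averages the raw critic outputs, with the gradient penalty added separately.

```python
import tensorflow as tf

def standard_gan_d_loss(d_real_logits, d_fake_logits):
    # Standard GAN: sigmoid cross-entropy against labels of 1 for real, 0 for fake.
    real_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(d_real_logits), logits=d_real_logits)
    fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(d_fake_logits), logits=d_fake_logits)
    return tf.reduce_mean(real_loss) + tf.reduce_mean(fake_loss)

def wgan_d_loss(d_real_score, d_fake_score):
    # WGAN: no sigmoid and no cross-entropy; the critic's raw outputs are
    # averaged directly (the gradient penalty term is added outside).
    return tf.reduce_mean(d_fake_score) - tf.reduce_mean(d_real_score)
```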
2. After changing the loss to use the raw logits as WGAN requires, I found that training would not converge. After repeated checking, I traced the problem to the gradient penalty computation; after rewriting the original function as below, training converges well:
```python
def gradient_penalty(discriminator, real_image, fake_image):
    batchsz = real_image.shape[0]
    # did the dtype cause the non-convergence?
    t = tf.random.uniform([batchsz, 1, 1, 1], minval=0., maxval=1., dtype=tf.float32)
    # random interpolation between real and fake samples
    x_hat = t * real_image + (1. - t) * fake_image
    with tf.GradientTape() as tape:
        tape.watch(x_hat)
        Dx = discriminator(x_hat, training=True)
    # gradient of the critic output w.r.t. the interpolated input
    grads = tape.gradient(Dx, x_hat)
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    # penalize deviation of the per-sample gradient norm from 1
    gp = tf.reduce_mean((slopes - 1.) ** 2)
    return gp
```
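For completeness, here is a rough sketch of how these loss functions could be wired into a training loop. The latent dimension, optimizer settings (Adam with the beta values commonly used for WGAN-GP), and the number of critic steps per generator step are my assumptions, not the repo's exact configuration.

```python
import tensorflow as tf

# Hypothetical hyperparameters; replace with the repo's actual values.
z_dim = 100
g_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9)
d_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5, beta_2=0.9)

def train_step(generator, discriminator, real_image, d_steps=5):
    batchsz = real_image.shape[0]
    # Train the critic several times per generator update, as in WGAN-GP.
    for _ in range(d_steps):
        batch_z = tf.random.normal([batchsz, z_dim])
        with tf.GradientTape() as tape:
            d_loss, gp = d_loss_fn(generator, discriminator, batch_z, real_image)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

    # One generator update on a fresh batch of noise.
    batch_z = tf.random.normal([batchsz, z_dim])
    with tf.GradientTape() as tape:
        g_loss = g_loss_fn(generator, discriminator, batch_z)
    grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    return d_loss, g_loss, gp
```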