
Two questions about the WGAN-GP source code #45

@donpromax

Description


While reading the source code I noticed a couple of small issues.

  1. The wgan_train.py source still applies a sigmoid followed by a cross-entropy loss, but WGAN should use the Discriminator's output logits directly as the loss:
def d_loss_fn(generator, discriminator, batch_z, real_image):
    fake_image = generator(batch_z, training=True)
    d_fake_score = discriminator(fake_image, training=True)
    d_real_score = discriminator(real_image, training=True)

    # Wasserstein critic loss: E[D(fake)] - E[D(real)]
    loss = tf.reduce_mean(d_fake_score - d_real_score)
    # gradient penalty term, lambda = 10
    gp = gradient_penalty(discriminator, real_image, fake_image) * 10.

    loss = loss + gp
    return loss, gp

def g_loss_fn(generator, discriminator, batch_z):
    fake_image = generator(batch_z, training=True)
    d_fake_logits = discriminator(fake_image, training=True)
    # loss = celoss_ones(d_fake_logits)  # original cross-entropy version
    # generator maximizes E[D(fake)], i.e. minimizes -E[D(fake)]
    loss = -tf.reduce_mean(d_fake_logits)
    return loss
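
For contrast, the cross-entropy helper referenced by the commented-out celoss_ones call presumably looks roughly like the sketch below; this is an assumption about wgan_train.py, not a quote from it (celoss_zeros is likewise hypothetical). The Wasserstein formulation above drops the sigmoid/cross-entropy entirely and averages the raw critic scores:

import tensorflow as tf

# Hypothetical sketch of the repo's cross-entropy helpers: they push
# sigmoid(logits) toward 1 or 0, i.e. the standard GAN loss, not the
# Wasserstein loss.
def celoss_ones(logits):
    loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(logits), logits=logits)
    return tf.reduce_mean(loss)

def celoss_zeros(logits):
    loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(logits), logits=logits)
    return tf.reduce_mean(loss)

# standard GAN discriminator loss (what point 1 says should be replaced):
#   d_loss = celoss_ones(d_real_logits) + celoss_zeros(d_fake_logits)
# WGAN-GP critic loss (as in d_loss_fn above):
#   d_loss = mean(d_fake_score) - mean(d_real_score) + 10 * gp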

2. After switching the loss to use the logits directly, as WGAN requires, I found training could not converge. After repeated checking, the problem turned out to be in the gradient penalty computation; after replacing the original function with the one below, training converges well:

def gradient_penalty(discriminator, real_image, fake_image):
    batchsz = real_image.shape[0]
    # could a wrong dtype have caused the non-convergence?
    t = tf.random.uniform([batchsz, 1, 1, 1], minval=0., maxval=1., dtype=tf.float32)
    # interpolate between real and fake samples, one t per sample
    x_hat = t * real_image + (1. - t) * fake_image
    with tf.GradientTape() as tape:
        tape.watch(x_hat)
        Dx = discriminator(x_hat, training=True)
    # per-sample gradient of the critic output w.r.t. the interpolated input
    grads = tape.gradient(Dx, x_hat)
    # L2 norm of that gradient over the H, W, C axes
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    # penalize deviation of the gradient norm from 1
    gp = tf.reduce_mean((slopes - 1.) ** 2)
    return gp
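
For completeness, here is a minimal sketch of how these functions could be wired into a TF2 training step. The optimizers, learning rate, and the 5 critic updates per generator update are my assumptions, not necessarily what wgan_train.py does:

import tensorflow as tf

# Assumed optimizer settings (WGAN-GP papers typically use Adam with beta_1
# close to 0); the repo may use different values.
d_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
g_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

def train_step(generator, discriminator, real_image, z_dim, n_critic=5):
    batchsz = real_image.shape[0]
    # update the critic several times per generator step (usual WGAN-GP practice)
    for _ in range(n_critic):
        batch_z = tf.random.normal([batchsz, z_dim])
        with tf.GradientTape() as tape:
            d_loss, gp = d_loss_fn(generator, discriminator, batch_z, real_image)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

    # one generator update
    batch_z = tf.random.normal([batchsz, z_dim])
    with tf.GradientTape() as tape:
        g_loss = g_loss_fn(generator, discriminator, batch_z)
    grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    return d_loss, gp, g_loss

Note that gradient_penalty's inner GradientTape runs inside the outer tape here, so the outer tape can differentiate the penalty term with respect to the discriminator weights (second-order gradients), which is the standard TF2 pattern for WGAN-GP.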
