Description
Hi, since your WaveGlow work proposes using a soft-EM version of VQ-VAE, my core implementation is:
"
def _square_distance(x, code_book):
x = tf.cast(x, tf.float32)
code_book = tf.cast(code_book, tf.float32)
x_sg = tf.stop_gradient(x)
x_norm_sq = tf.reduce_sum(tf.square(x_sg), axis=-1, keepdims=True) # [b, 1]
code_book_norm_sq = tf.reduce_sum(tf.square(code_book), axis=-1, keepdims=True) # [V, 1]
scalar_prod = tf.matmul(x_sg, code_book, transpose_b=True) # [b, V]
dist_sq = x_norm_sq + tf.transpose(code_book_norm_sq) - 2 * scalar_prod # [b, V]
return tf.cast(dist_sq, x.dtype.base_dtype)
dist_sq = _square_distance(x, code_book)
q = tf.stop_gradient(tf.nn.softmax(-.5 * dist_sq))
discrete = tf.one_hot(tf.argmax(-dist_sq, axis=-1), depth=bottleneck_size, dtype=code_book.dtype.base_dtype)
dense = tf.matmul(discrete, code_book)
dense = dense + x - tf.stop_gradient(x)
def _get_losses(x, x_mask, dense, dist_sq, q):
x = tf.cast(x, tf.float32)
x_mask = tf.cast(x_mask, tf.float32)
dense = tf.cast(dense, tf.float32)
dist_sq = tf.cast(dist_sq, tf.float32)
q = tf.cast(q, tf.float32)
disc_loss = tf.reduce_sum(tf.reduce_sum(tf.square(x - tf.stop_gradient(dense)), -1)*x_mask) / (1e-10+tf.reduce_sum(x_mask))
# # M-step
em_loss = -tf.reduce_sum(tf.reduce_sum(-.5 * dist_sq * q, -1)*x_mask) / (1e-10+tf.reduce_sum(x_mask))
return disc_loss, em_loss
disc_loss, em_loss = _get_losses(x, x_mask, dense, dist_sq, q)
"
However, tensor2tensor has a different implementation:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/vq_discrete.py

It differs in two ways:
- it draws multiple samples to estimate the mean of the soft alignment (see the sketch after this list)
- when calculating the EM loss, it uses a different type of loss function than your "M-step"
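Here is a minimal sketch of what I understand the multisample assignment in `vq_discrete.py` to do; the standalone function and the `num_samples` parameter are my simplification of the t2t hparams plumbing, not its actual API:

```python
import tensorflow as tf

def multisample_soft_assignment(dist_sq, bottleneck_size, num_samples=10):
    """Approximate the soft alignment by averaging sampled one-hot codes.

    dist_sq: [b, V] squared distances between inputs and code book entries.
    Returns a [b, V] Monte-Carlo estimate of the soft posterior.
    """
    # Sample num_samples code indices per input from Categorical(softmax(-dist_sq)).
    nearest_idx = tf.random.categorical(-dist_sq, num_samples)    # [b, num_samples]
    nearest_hot = tf.one_hot(nearest_idx, depth=bottleneck_size)  # [b, num_samples, V]
    # Average the one-hot samples to get the mean soft alignment.
    return tf.reduce_mean(nearest_hot, axis=1)                    # [b, V]
```

So instead of my closed-form `q = softmax(-0.5 * dist_sq)`, t2t seems to use a sampled estimate of the alignment, and its loss is then built from this averaged assignment rather than from the expected log-likelihood.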
Could you help me with this?