From 241bc537d9ac81bd1b064e8417dfa634853b5aeb Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 2 Apr 2018 00:27:59 +1200 Subject: [PATCH] Updates to VAE (and ops) to enable variable batch sizes --- VAE.py | 18 ++++++++---------- ops.py | 13 +++++++++---- 2 files changed, 17 insertions(+), 14 deletions(-) mode change 100644 => 100755 VAE.py mode change 100644 => 100755 ops.py diff --git a/VAE.py b/VAE.py old mode 100644 new mode 100755 index 65f20e78..00ce7520 --- a/VAE.py +++ b/VAE.py @@ -55,7 +55,7 @@ def encoder(self, x, is_training=True, reuse=False): net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='en_conv1')) net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='en_conv2'), is_training=is_training, scope='en_bn2')) - net = tf.reshape(net, [self.batch_size, -1]) + net = tf.contrib.layers.flatten(net) net = lrelu(bn(linear(net, 1024, scope='en_fc3'), is_training=is_training, scope='en_bn3')) gaussian_params = linear(net, 2 * self.z_dim, scope='en_fc4') @@ -74,25 +74,24 @@ def decoder(self, z, is_training=True, reuse=False): with tf.variable_scope("decoder", reuse=reuse): net = tf.nn.relu(bn(linear(z, 1024, scope='de_fc1'), is_training=is_training, scope='de_bn1')) net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='de_fc2'), is_training=is_training, scope='de_bn2')) - net = tf.reshape(net, [self.batch_size, 7, 7, 128]) + net = tf.reshape(net, [-1, 7, 7, 128]) net = tf.nn.relu( - bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='de_dc3'), is_training=is_training, + bn(deconv2d(net, [None, 14, 14, 64], 4, 4, 2, 2, name='de_dc3'), is_training=is_training, scope='de_bn3')) - out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='de_dc4')) + out = tf.nn.sigmoid(deconv2d(net, [None, 28, 28, 1], 4, 4, 2, 2, name='de_dc4')) return out def build_model(self): # some parameters image_dims = [self.input_height, self.input_width, self.c_dim] - bs = self.batch_size """ Graph Input """ # images - self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images') + self.inputs = tf.placeholder(tf.float32, [None] + image_dims, name='real_images') # noises - self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z') + self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z') """ Loss Function """ # encoding @@ -247,9 +246,8 @@ def visualize_results(self, epoch): @property def model_dir(self): - return "{}_{}_{}_{}".format( - self.model_name, self.dataset_name, - self.batch_size, self.z_dim) + return "{}_{}_{}".format( + self.model_name, self.dataset_name, self.z_dim) def save(self, checkpoint_dir, step): checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name) diff --git a/ops.py b/ops.py old mode 100644 new mode 100755 index 53dccac1..caa4a1fa --- a/ops.py +++ b/ops.py @@ -41,14 +41,19 @@ def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="co conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) - conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + conv = tf.nn.bias_add(conv, biases) return conv def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, name="deconv2d", stddev=0.02, with_w=False): with tf.variable_scope(name): + output_shape_chan = output_shape[-1] + if output_shape[0] is None: + batch_size = tf.shape(input_)[0] + output_shape = tf.stack([batch_size] + output_shape[1:]) + # filter : [height, width, output_channels, in_channels] - w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], + w = tf.get_variable('w', [k_h, k_w, output_shape_chan, input_.get_shape()[-1]], initializer=tf.random_normal_initializer(stddev=stddev)) try: @@ -58,8 +63,8 @@ def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, name="deconv2d", except AttributeError: deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1]) - biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) - deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) + biases = tf.get_variable('biases', [output_shape_chan], initializer=tf.constant_initializer(0.0)) + deconv = tf.nn.bias_add(deconv, biases) if with_w: return deconv, w, biases