Skip to content

Commit 94e0a55

Browse files
committed
choose latent distribution using --z_dist from { 'normal01': standard normal; 'uniform_signed': (-1,1); 'uniform_unsigned':(0,1) }
1 parent 351a654 commit 94e0a55

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

main.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
2626
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
2727
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
28-
flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")
28+
flags.DEFINE_integer("z_dim", 100, "dimensions of z")
29+
flags.DEFINE_string("z_dist", "uniform_signed", "'normal01' or 'uniform_unsigned' or uniform_signed")
30+
#flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")
2931
FLAGS = flags.FLAGS
3032

3133
def main(_):
@@ -56,7 +58,7 @@ def main(_):
5658
batch_size=FLAGS.batch_size,
5759
sample_num=FLAGS.batch_size,
5860
y_dim=10,
59-
z_dim=FLAGS.generate_test_images,
61+
z_dim=FLAGS.z_dim,
6062
dataset_name=FLAGS.dataset,
6163
input_fname_pattern=FLAGS.input_fname_pattern,
6264
crop=FLAGS.crop,
@@ -72,7 +74,7 @@ def main(_):
7274
output_height=FLAGS.output_height,
7375
batch_size=FLAGS.batch_size,
7476
sample_num=FLAGS.batch_size,
75-
z_dim=FLAGS.generate_test_images,
77+
z_dim=FLAGS.z_dim,
7678
dataset_name=FLAGS.dataset,
7779
input_fname_pattern=FLAGS.input_fname_pattern,
7880
crop=FLAGS.crop,

model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
def conv_out_size_same(size, stride):
1414
return int(math.ceil(float(size) / float(stride)))
1515

16+
def gen_random(mode, size):
17+
if mode=='normal01': return np.random.normal(0,1,size=size)
18+
if mode=='uniform_signed': return np.random.uniform(-1,1,size=size)
19+
if mode=='uniform_unsigned': return np.random.uniform(0,1,size=size)
20+
21+
1622
class DCGAN(object):
1723
def __init__(self, sess, input_height=108, input_width=108, crop=True,
1824
batch_size=64, sample_num = 64, output_height=64, output_width=64,
@@ -166,7 +172,7 @@ def train(self, config):
166172
[self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
167173
self.writer = SummaryWriter("./logs", self.sess.graph)
168174

169-
sample_z = np.random.uniform(-1, 1, size=(self.sample_num , self.z_dim))
175+
sample_z = gen_random(config.z_dist, size=(self.sample_num , self.z_dim))
170176

171177
if config.dataset == 'mnist':
172178
sample_inputs = self.data_X[0:self.sample_num]
@@ -223,7 +229,7 @@ def train(self, config):
223229
else:
224230
batch_images = np.array(batch).astype(np.float32)
225231

226-
batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \
232+
batch_z = gen_random(config.z_dist, size=[config.batch_size, self.z_dim]) \
227233
.astype(np.float32)
228234

229235
if config.dataset == 'mnist':

0 commit comments

Comments
 (0)