|
13 | 13 | def conv_out_size_same(size, stride): |
14 | 14 | return int(math.ceil(float(size) / float(stride))) |
15 | 15 |
|
| 16 | +def gen_random(mode, size): |
| 17 | + if mode=='normal01': return np.random.normal(0,1,size=size) |
| 18 | + if mode=='uniform_signed': return np.random.uniform(-1,1,size=size) |
| 19 | + if mode=='uniform_unsigned': return np.random.uniform(0,1,size=size) |
| 20 | + |
| 21 | + |
16 | 22 | class DCGAN(object): |
17 | 23 | def __init__(self, sess, input_height=108, input_width=108, crop=True, |
18 | 24 | batch_size=64, sample_num = 64, output_height=64, output_width=64, |
@@ -166,7 +172,7 @@ def train(self, config): |
166 | 172 | [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum]) |
167 | 173 | self.writer = SummaryWriter("./logs", self.sess.graph) |
168 | 174 |
|
169 | | - sample_z = np.random.uniform(-1, 1, size=(self.sample_num , self.z_dim)) |
| 175 | + sample_z = gen_random(config.z_dist, size=(self.sample_num , self.z_dim)) |
170 | 176 |
|
171 | 177 | if config.dataset == 'mnist': |
172 | 178 | sample_inputs = self.data_X[0:self.sample_num] |
@@ -223,7 +229,7 @@ def train(self, config): |
223 | 229 | else: |
224 | 230 | batch_images = np.array(batch).astype(np.float32) |
225 | 231 |
|
226 | | - batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \ |
| 232 | + batch_z = gen_random(config.z_dist, size=[config.batch_size, self.z_dim]) \ |
227 | 233 | .astype(np.float32) |
228 | 234 |
|
229 | 235 | if config.dataset == 'mnist': |
|
0 commit comments