@@ -18,7 +18,7 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
1818 batch_size = 64 , sample_num = 64 , output_height = 64 , output_width = 64 ,
1919 y_dim = None , z_dim = 100 , gf_dim = 64 , df_dim = 64 ,
2020 gfc_dim = 1024 , dfc_dim = 1024 , c_dim = 3 , dataset_name = 'default' ,
21- input_fname_pattern = '*.jpg' , checkpoint_dir = None , sample_dir = None ):
21+ input_fname_pattern = '*.jpg' , checkpoint_dir = None , sample_dir = None , data_dir = './data' ):
2222 """
2323
2424 Args:
@@ -69,12 +69,13 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
6969 self .dataset_name = dataset_name
7070 self .input_fname_pattern = input_fname_pattern
7171 self .checkpoint_dir = checkpoint_dir
72+ self .data_dir = data_dir
7273
7374 if self .dataset_name == 'mnist' :
7475 self .data_X , self .data_y = self .load_mnist ()
7576 self .c_dim = self .data_X [0 ].shape [- 1 ]
7677 else :
77- self .data = glob (os .path .join ("./data" , self .dataset_name , self .input_fname_pattern ))
78+ self .data = glob (os .path .join (self . data_dir , self .dataset_name , self .input_fname_pattern ))
7879 imreadImg = imread (self .data [0 ])
7980 if len (imreadImg .shape ) >= 3 : #check if image is a non-grayscale image by checking channel number
8081 self .c_dim = imread (self .data [0 ]).shape [- 1 ]
@@ -192,7 +193,7 @@ def train(self, config):
192193 batch_idxs = min (len (self .data_X ), config .train_size ) // config .batch_size
193194 else :
194195 self .data = glob (os .path .join (
195- "./data" , config .dataset , self .input_fname_pattern ))
196+ config . data_dir , config .dataset , self .input_fname_pattern ))
196197 batch_idxs = min (len (self .data ), config .train_size ) // config .batch_size
197198
198199 for idx in xrange (0 , batch_idxs ):
@@ -451,7 +452,7 @@ def sampler(self, z, y=None):
451452 return tf .nn .sigmoid (deconv2d (h2 , [self .batch_size , s_h , s_w , self .c_dim ], name = 'g_h3' ))
452453
453454 def load_mnist (self ):
454- data_dir = os .path .join ("./data" , self .dataset_name )
455+ data_dir = os .path .join (self . data_dir , self .dataset_name )
455456
456457 fd = open (os .path .join (data_dir ,'train-images-idx3-ubyte' ))
457458 loaded = np .fromfile (file = fd ,dtype = np .uint8 )
0 commit comments