@@ -508,28 +508,39 @@ def model_dir(self):
508508 return "{}_{}_{}_{}" .format (
509509 self .dataset_name , self .batch_size ,
510510 self .output_height , self .output_width )
511-
512- def save (self , checkpoint_dir , step ):
513- model_name = "DCGAN.model"
514- checkpoint_dir = os .path .join (checkpoint_dir , self .model_dir )
515511
512+ def save (self , checkpoint_dir , step , filename = 'model' , ckpt = True , frozen = False ):
513+ # model_name = "DCGAN.model"
514+ # checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
515+
516+ filename += '.b' + str (self .batch_size )
516517 if not os .path .exists (checkpoint_dir ):
517518 os .makedirs (checkpoint_dir )
518519
519- self .saver .save (self .sess ,
520- os .path .join (checkpoint_dir , model_name ),
521- global_step = step )
520+ if ckpt :
521+ self .saver .save (self .sess ,
522+ os .path .join (checkpoint_dir , filename ),
523+ global_step = step )
524+
525+ if frozen :
526+ tf .train .write_graph (
527+ tf .graph_util .convert_variables_to_constants (self .sess , self .sess .graph_def , ["generator_1/Tanh" ]),
528+ checkpoint_dir ,
529+ '{}-{:06d}_frz.pb' .format (filename , step ),
530+ as_text = False )
522531
523532 def load (self , checkpoint_dir ):
524- import re
525- print (" [*] Reading checkpoints..." )
526- checkpoint_dir = os .path .join (checkpoint_dir , self .model_dir )
533+ #import re
534+ print (" [*] Reading checkpoints..." , checkpoint_dir )
535+ # checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
536+ # print(" ->", checkpoint_dir)
527537
528538 ckpt = tf .train .get_checkpoint_state (checkpoint_dir )
529539 if ckpt and ckpt .model_checkpoint_path :
530540 ckpt_name = os .path .basename (ckpt .model_checkpoint_path )
531541 self .saver .restore (self .sess , os .path .join (checkpoint_dir , ckpt_name ))
532- counter = int (next (re .finditer ("(\d+)(?!.*\d)" ,ckpt_name )).group (0 ))
542+ #counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
543+ counter = int (ckpt_name .split ('-' )[- 1 ])
533544 print (" [*] Success to read {}" .format (ckpt_name ))
534545 return True , counter
535546 else :
0 commit comments