Skip to content

Commit ec8d0e5

Browse files
committed
Options to export, freeze and prune graph
1 parent 1f358e8 commit ec8d0e5

File tree

2 files changed

+39
-16
lines changed

2 files changed

+39
-16
lines changed

main.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
2626
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
2727
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
28+
flags.DEFINE_boolean("export", False, "True for exporting with new batch size")
29+
flags.DEFINE_boolean("freeze", False, "True for exporting with new batch size")
2830
flags.DEFINE_integer("z_dim", 100, "dimensions of z")
2931
flags.DEFINE_string("z_dist", "uniform_signed", "'normal01' or 'uniform_unsigned' or uniform_signed")
3032
#flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")
@@ -87,9 +89,10 @@ def main(_):
8789
if FLAGS.train:
8890
dcgan.train(FLAGS)
8991
else:
90-
if not dcgan.load(FLAGS.checkpoint_dir)[0]:
91-
raise Exception("[!] Train a model first, then run test mode")
92-
92+
load_success, load_counter = dcgan.load(FLAGS.checkpoint_dir)
93+
if not load_success:
94+
raise Exception("Checkpoint not found in " + FLAGS.checkpoint_dir)
95+
9396

9497
# to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
9598
# [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
@@ -98,8 +101,17 @@ def main(_):
98101
# [dcgan.h4_w, dcgan.h4_b, None])
99102

100103
# Below is codes for visualization
101-
OPTION = 1
102-
visualize(sess, dcgan, FLAGS, OPTION)
104+
if FLAGS.export:
105+
export_dir = os.path.join(FLAGS.checkpoint_dir, 'export_b'+str(FLAGS.batch_size))
106+
dcgan.save(export_dir, load_counter, ckpt=True, frozen=False)
107+
108+
if FLAGS.freeze:
109+
export_dir = os.path.join(FLAGS.checkpoint_dir, 'frozen_b'+str(FLAGS.batch_size))
110+
dcgan.save(export_dir, load_counter, ckpt=False, frozen=True)
111+
112+
if FLAGS.visualize:
113+
OPTION = 1
114+
visualize(sess, dcgan, FLAGS, OPTION, FLAGS.sample_dir)
103115

104116
if __name__ == '__main__':
105117
tf.app.run()

model.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -508,28 +508,39 @@ def model_dir(self):
508508
return "{}_{}_{}_{}".format(
509509
self.dataset_name, self.batch_size,
510510
self.output_height, self.output_width)
511-
512-
def save(self, checkpoint_dir, step):
513-
model_name = "DCGAN.model"
514-
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
515511

512+
def save(self, checkpoint_dir, step, filename='model', ckpt=True, frozen=False):
513+
# model_name = "DCGAN.model"
514+
# checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
515+
516+
filename += '.b' + str(self.batch_size)
516517
if not os.path.exists(checkpoint_dir):
517518
os.makedirs(checkpoint_dir)
518519

519-
self.saver.save(self.sess,
520-
os.path.join(checkpoint_dir, model_name),
521-
global_step=step)
520+
if ckpt:
521+
self.saver.save(self.sess,
522+
os.path.join(checkpoint_dir, filename),
523+
global_step=step)
524+
525+
if frozen:
526+
tf.train.write_graph(
527+
tf.graph_util.convert_variables_to_constants(self.sess, self.sess.graph_def, ["generator_1/Tanh"]),
528+
checkpoint_dir,
529+
'{}-{:06d}_frz.pb'.format(filename, step),
530+
as_text=False)
522531

523532
def load(self, checkpoint_dir):
524-
import re
525-
print(" [*] Reading checkpoints...")
526-
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
533+
#import re
534+
print(" [*] Reading checkpoints...", checkpoint_dir)
535+
# checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
536+
# print(" ->", checkpoint_dir)
527537

528538
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
529539
if ckpt and ckpt.model_checkpoint_path:
530540
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
531541
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
532-
counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
542+
#counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
543+
counter = int(ckpt_name.split('-')[-1])
533544
print(" [*] Success to read {}".format(ckpt_name))
534545
return True, counter
535546
else:

0 commit comments

Comments
 (0)