
Commit de11ba8

Merge pull request #208 from GeiduanLiu/devel: transform parameters from old model to new model.
2 parents: 9cd81c4 + 8385272

3 files changed: 74 additions, 1 deletion


source/train/Fitting.py

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ def build (self,
             if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
                 layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii])
             else :
-                layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[ii])
+                layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii])
         final_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[-1])

         if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None:
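
The one-line change passes the configured fitting activation explicitly to the non-resnet branch; previously that call omitted activation_fn, so it fell back to one_layer's built-in default rather than self.fitting_activation_fn. A minimal sketch of the pitfall, with a simplified stand-in for one_layer (the dense layer and the tanh default here are assumptions for illustration only, not the real helper):

    # Simplified stand-in for one_layer; the tanh default is only an assumption.
    import tensorflow as tf

    def one_layer(inputs, width, activation_fn=tf.nn.tanh, **kwargs):
        hidden = tf.layers.dense(inputs, width)   # weights + bias, as in the real helper
        return activation_fn(hidden) if activation_fn is not None else hidden

    # Before this commit the else-branch above called one_layer(...) without
    # activation_fn, so it always used the default instead of the configured
    # self.fitting_activation_fn; the added keyword makes both branches agree.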

source/train/main.py

Lines changed: 10 additions & 0 deletions
@@ -4,6 +4,7 @@
 from .freeze import freeze
 from .config import config
 from .test import test
+from .transform import transform

 def main () :
     parser = argparse.ArgumentParser(
@@ -15,6 +16,13 @@ def main () :
     # help="the output json file")

     default_num_inter_threads = 0
+    parser_transform = subparsers.add_parser('transform', help='pass parameters to another model')
+    parser_transform.add_argument('-r', "--raw-model", default = "raw_frozen_model.pb", type=str,
+                                  help = "the model receiving parameters")
+    parser_transform.add_argument("-o","--old-model", default = "old_frozen_model.pb", type=str,
+                                  help='the model providing parameters')
+    parser_transform.add_argument("-n", "--output", default = "frozen_model.pb", type=str,
+                                  help = "the model after passing parameters")
     parser_train = subparsers.add_parser('train', help='train a model')
     parser_train.add_argument('INPUT',
                               help='the input parameter file in json format')
@@ -62,5 +70,7 @@ def main () :
         config(args)
     elif args.command == 'test' :
         test(args)
+    elif args.command == 'transform' :
+        transform(args)
     else :
         raise RuntimeError('unknown command ' + args.command)
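
With these flags wired up, the subcommand would be invoked as something like "dp transform -r raw_frozen_model.pb -o old_frozen_model.pb -n frozen_model.pb", assuming the console entry point that wraps main() is called dp. For a quick test the argparse layer can also be bypassed; a sketch, where the file names are placeholders and the import path assumes source/train installs as the deepmd package:

    # Hypothetical driver for the new subcommand, bypassing the console script.
    # Both .pb files must be frozen graphs built from the same network definition,
    # otherwise load_data() raises a RuntimeError.
    import argparse
    from deepmd.transform import transform   # import path is an assumption

    args = argparse.Namespace(raw_model="raw_frozen_model.pb",   # model receiving parameters
                              old_model="old_frozen_model.pb",   # model providing parameters
                              output="frozen_model.pb")          # model written after the copy
    transform(args)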

source/train/transform.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+from deepmd.env import tf
+import re
+def transform(args):
+    new_graph = load_graph(args.raw_model)
+    old_graph = load_graph(args.old_model)
+    print("%d ops in the raw graph\n%d ops in the old graph" %(len(new_graph.node),len(old_graph.node)))
+    transform_node = load_data(new_graph,old_graph)
+    for node in new_graph.node:
+        if node.name in transform_node:
+            print("%s is passed from old graph to raw graph" % node.name)
+            node.attr["value"].tensor.CopyFrom(transform_node[node.name].attr["value"].tensor)
+    with tf.gfile.GFile(args.output, mode='wb') as f:
+        f.write(new_graph.SerializeToString())
+    print("the output model is saved in %s" % args.output)
+
+def load_graph(graphName):
+    graph_def = tf.GraphDef()
+    with open(graphName,"rb") as f:
+        graph_def.ParseFromString(f.read())
+    with tf.Graph().as_default() as graph:
+        tf.import_graph_def(graph_def,name = "")
+    return graph_def
+
+def load_data(new_graph,old_graph):
+    new_graph_node = load_transform_node(new_graph)
+    old_graph_node = load_transform_node(old_graph)
+    if len(new_graph_node) != len(old_graph_node):
+        raise RuntimeError("New graph and original graph has different network structure\n")
+    for nodeName in old_graph_node.keys():
+        check_dim(new_graph_node, old_graph_node, nodeName)
+        check_precision(new_graph_node, old_graph_node, nodeName)
+    return old_graph_node
+
+
+def check_precision(new_graph_node, old_graph_node, nodeName):
+    new_graph_precision = new_graph_node[nodeName].attr["value"].tensor.dtype
+    old_graph_precision = old_graph_node[nodeName].attr["value"].tensor.dtype
+    if new_graph_precision != old_graph_precision:
+        raise RuntimeError("New graph and original graph has different"+nodeName+" precision\n")
+
+def check_dim(new_graph_node, old_graph_node, nodeName):
+    new_graph_dim = new_graph_node[nodeName].attr["value"].tensor.tensor_shape
+    old_graph_dim = old_graph_node[nodeName].attr["value"].tensor.tensor_shape
+    if new_graph_dim != old_graph_dim:
+        raise RuntimeError("New graph and original graph has different"+nodeName+" dim\n")
+
+
+def load_transform_node(graph):
+    transform_node = {}
+    transform_node_pattern = "\
+filter_type_\d+/matrix_\d+_\d+|\
+filter_type_\d+/bias_\d+_\d+|\
+filter_type_\d+/idt_\d+_\d+\
+layer_\d+_type_\d+/matrix|\
+layer_\d+_type_\d+/bias|\
+layer_\d+_type_\d+/idt|\
+final_layer_type_\d+/matrix|\
+final_layer_type_\d+/bias\
+"
+    for node in graph.node:
+        if re.fullmatch(transform_node_pattern,node.name) != None:
+            transform_node[node.name] = node
+    return transform_node
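
For comparison, the same name-matching idea can be used to inspect which constant nodes of a frozen model would be copied before actually running transform. A small exploratory sketch; the regex below spells out the intended alternatives explicitly (it illustrates the naming scheme, it is not the exact string used above), and the file path is a placeholder:

    # List the embedding/fitting parameters a frozen model exposes as Const nodes.
    import re
    import tensorflow as tf

    PARAM_NAME = re.compile(
        r"filter_type_\d+/(matrix|bias|idt)_\d+_\d+"
        r"|layer_\d+_type_\d+/(matrix|bias|idt)"
        r"|final_layer_type_\d+/(matrix|bias)")

    graph_def = tf.GraphDef()
    with open("old_frozen_model.pb", "rb") as f:   # placeholder path
        graph_def.ParseFromString(f.read())

    for node in graph_def.node:
        if PARAM_NAME.fullmatch(node.name):
            shape = [d.size for d in node.attr["value"].tensor.tensor_shape.dim]
            print(node.name, shape)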
