
Commit 28dc891

Merge pull request #14 from deepmodeling/devel
Devel update
2 parents 86c7fcb + 2b02556

3 files changed: +97 -1 lines changed


source/train/Fitting.py

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ def build (self,
             if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
                 layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii])
             else :
-                layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[ii])
+                layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii])
             final_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[-1])

             if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None:
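The one-line change above forwards the configured fitting_activation_fn on the plain (non-resnet) branch; before the fix, the call omitted activation_fn, so one_layer silently fell back to its default activation no matter what the user configured for the fitting network. A minimal sketch of that failure mode in plain NumPy, assuming a tanh default as in deepmd's layer helper (one_layer_sketch is a hypothetical stand-in, not the repository's code):

import numpy as np

def one_layer_sketch(x, w, activation_fn=np.tanh):
    # hypothetical stand-in for deepmd's one_layer: the default activation
    # applies whenever the caller does not forward activation_fn explicitly
    h = x @ w
    return activation_fn(h) if activation_fn is not None else h

x = np.ones((1, 2))
w = np.ones((2, 2))
relu = lambda v: np.maximum(v, 0.0)   # stand-in for a configured fitting_activation_fn

print(one_layer_sketch(x, w))                      # buggy call site: tanh despite the config
print(one_layer_sketch(x, w, activation_fn=relu))  # fixed call site: configured fn is applied

The resnet branch two lines earlier already passed activation_fn, which is why only the else branch needed the patch.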

source/train/main.py

Lines changed: 10 additions & 0 deletions
@@ -4,6 +4,7 @@
 from .freeze import freeze
 from .config import config
 from .test import test
+from .transform import transform

 def main () :
     parser = argparse.ArgumentParser(
@@ -15,6 +16,13 @@ def main () :
     # help="the output json file")

     default_num_inter_threads = 0
+    parser_transform = subparsers.add_parser('transform', help='pass parameters to another model')
+    parser_transform.add_argument('-r', "--raw-model", default = "raw_frozen_model.pb", type=str,
+                                  help = "the model receiving parameters")
+    parser_transform.add_argument("-o","--old-model", default = "old_frozen_model.pb", type=str,
+                                  help='the model providing parameters')
+    parser_transform.add_argument("-n", "--output", default = "frozen_model.pb", type=str,
+                                  help = "the model after passing parameters")
     parser_train = subparsers.add_parser('train', help='train a model')
     parser_train.add_argument('INPUT',
                               help='the input parameter file in json format')
@@ -62,5 +70,7 @@ def main () :
         config(args)
     elif args.command == 'test' :
         test(args)
+    elif args.command == 'transform' :
+        transform(args)
     else :
         raise RuntimeError('unknown command ' + args.command)
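With the subparser registered and dispatched, the feature is reachable from the command line, e.g. dp transform -r raw_frozen_model.pb -o old_frozen_model.pb -n frozen_model.pb (assuming the package's dp entry point routes through this main, as in deepmd-kit v1.x). For a quick test the handler can also be called directly; a sketch, where the import path is an assumption about how source/train is installed:

from argparse import Namespace
from deepmd.transform import transform  # import path is an assumption

# Mirror the argparse defaults above; argparse maps --raw-model to
# args.raw_model, --old-model to args.old_model, --output to args.output.
args = Namespace(raw_model="raw_frozen_model.pb",
                 old_model="old_frozen_model.pb",
                 output="frozen_model.pb")
transform(args)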

source/train/transform.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+from deepmd.env import tf
+import re
+import numpy as np
+def transform(args):
+    raw_graph = load_graph(args.raw_model)
+    old_graph = load_graph(args.old_model)
+    print("%d ops in the raw graph\n%d ops in the old graph" %(len(raw_graph.as_graph_def().node),len(old_graph.as_graph_def().node)))
+    new_graph_def = transform_graph(raw_graph,old_graph)
+    with tf.gfile.GFile(args.output, mode='wb') as f:
+        f.write(new_graph_def.SerializeToString())
+    print("the output model is saved in %s" % args.output)
+
+def load_graph(graphName):
+    graph_def = tf.GraphDef()
+    with open(graphName,"rb") as f:
+        graph_def.ParseFromString(f.read())
+    with tf.Graph().as_default() as graph:
+        tf.import_graph_def(graph_def,name = "")
+        return graph
+
+def transform_graph(raw_graph,old_graph):
+    precision_dict = {\
+    1:(np.float32, "float32"),\
+    2:(np.float64, "float64"),\
+    19:(np.float16, "float16")\
+    }
+    old_graph_def = old_graph.as_graph_def()
+    raw_graph_def = raw_graph.as_graph_def()
+    raw_graph_node = load_transform_node(raw_graph_def)
+    old_graph_node = load_transform_node(old_graph_def)
+
+    if len(raw_graph_node) != len(old_graph_node):
+        raise RuntimeError("raw graph and old graph has different network structure")
+
+    for node in raw_graph_def.node:
+        if node.name in raw_graph_node.keys():
+            if precision_dict[old_graph_node[node.name].dtype][1] == "float16" or precision_dict[raw_graph_node[node.name].dtype][1] == "float16":
+                raise RuntimeError("float16 conversions not currently supported")
+
+            check_dim(raw_graph_node, old_graph_node, node.name)
+
+            if re.fullmatch("final_layer_type_\d+/bias",node.name) == None:
+                tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content,dtype = precision_dict[old_graph_node[node.name].dtype][0])
+                tensor_value = tensor_value.astype(dtype=precision_dict[raw_graph_node[node.name].dtype][0])
+                node.attr["value"].tensor.tensor_content = tensor_value.tostring()
+
+            else:
+                if precision_dict[old_graph_node[node.name].dtype][1] == "float64":
+                    tensor_value = (np.array(old_graph_node[node.name].double_val)).astype(precision_dict[raw_graph_node[node.name].dtype][0])
+                    node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,precision_dict[raw_graph_node[node.name].dtype][0], [1])))
+
+                elif precision_dict[old_graph_node[node.name].dtype][1] == "float32":
+                    tensor_value = (np.array(old_graph_node[node.name].float_val)).astype(precision_dict[raw_graph_node[node.name].dtype][0])
+                    node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, precision_dict[raw_graph_node[node.name].dtype][0], [1])))
+
+                elif precision_dict[old_graph_node[node.name].dtype][1] == "float16":
+                    tensor_value = (np.array(old_graph_node[node.name].half_val)).astype(precision_dict[raw_graph_node[node.name].dtype][0])
+                    node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, precision_dict[raw_graph_node[node.name].dtype][0], [1])))
+
+            print("%s is passed from old graph(%s) to raw graph(%s)" % (node.name,precision_dict[old_graph_node[node.name].dtype][1],precision_dict[raw_graph_node[node.name].dtype][1]))
+
+    return raw_graph_def
+
+def check_dim(raw_graph_node, old_graph_node, node_name):
+    raw_graph_dim = raw_graph_node[node_name].tensor_shape
+    old_graph_dim = old_graph_node[node_name].tensor_shape
+    if raw_graph_dim != old_graph_dim:
+        raise RuntimeError("old graph and raw graph has different"+node_name+" dim")
+
+
+def load_transform_node(graph):
+    transform_node = {}
+    transform_node_pattern = "\
+filter_type_\d+/matrix_\d+_\d+|\
+filter_type_\d+/bias_\d+_\d+|\
+filter_type_\d+/idt_\d+_\d+|\
+layer_\d+_type_\d+/matrix|\
+layer_\d+_type_\d+/bias|\
+layer_\d+_type_\d+/idt|\
+final_layer_type_\d+/bias|\
+final_layer_type_\d+/matrix\
+"
+    for node in graph.node:
+        if re.fullmatch(transform_node_pattern,node.name) != None:
+            transform_node[node.name] = node.attr["value"].tensor
+    return transform_node
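transform_graph keys its precision handling on TensorFlow's DataType enum values, which is where the 1/2/19 keys of precision_dict come from: DT_FLOAT = 1, DT_DOUBLE = 2, DT_HALF = 19. Ordinary weight tensors carry their bytes in tensor_content, while the scalar final_layer_type_*/bias lands in the typed double_val/float_val/half_val fields, hence the re.fullmatch special case; and since the float16 guard raises before any conversion, the half_val branch is effectively unreachable in this version. A small sketch of the same node filtering, useful for checking which tensors a frozen graph would expose to the transform (TF1-style API as used in the commit; frozen_model.pb is a placeholder path):

import re
import tensorflow as tf

graph_def = tf.GraphDef()  # tf.compat.v1.GraphDef under TF2
with open("frozen_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# The same kind of full-match filtering load_transform_node performs,
# restricted to two of its eight patterns for brevity.
pattern = r"layer_\d+_type_\d+/matrix|final_layer_type_\d+/bias"
for node in graph_def.node:
    if re.fullmatch(pattern, node.name):
        t = node.attr["value"].tensor
        print(node.name, t.dtype, [d.size for d in t.tensor_shape.dim])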
