Skip to content

Commit 604ced7

Browse files
authored
Merge pull request #235 from GeiduanLiu/devel
Devel
2 parents ddcf17c + e141573 commit 604ced7

File tree

2 files changed

+15
-12
lines changed

source/train/Fitting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,14 @@ def build (self,
194194
if self.numb_fparam > 0 :
195195
ext_fparam = tf.tile(fparam, [1, natoms[2+type_i]])
196196
ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam])
197+
ext_fparam = tf.cast(ext_fparam,self.fitting_precision)
197198
layer = tf.concat([layer, ext_fparam], axis = 1)
198199
if self.numb_aparam > 0 :
199200
ext_aparam = tf.slice(aparam,
200201
[ 0, start_index * self.numb_aparam],
201202
[-1, natoms[2+type_i] * self.numb_aparam])
202203
ext_aparam = tf.reshape(ext_aparam, [-1, self.numb_aparam])
204+
ext_aparam = tf.cast(ext_aparam,self.fitting_precision)
203205
layer = tf.concat([layer, ext_aparam], axis = 1)
204206
start_index += natoms[2+type_i]
205207

source/train/transform.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,22 +49,18 @@ def transform_graph(raw_graph,old_graph):
4949

5050
for node in raw_graph_def.node:
5151
if node.name in raw_graph_node.keys():
52-
"""
53-
if precision_dict[old_graph_node[node.name].dtype][1] == "float16" or precision_dict[raw_graph_node[node.name].dtype][1] == "float16":
54-
raise RuntimeError("float16 conversions not currently supported")
55-
"""
5652

5753
check_dim(raw_graph_node, old_graph_node, node.name)
54+
tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
5855
old_graph_dtype = precision_dict[old_graph_node[node.name].dtype]
5956
raw_graph_dtype = precision_dict[raw_graph_node[node.name].dtype]
6057
print("%s is passed from old graph(%s) to raw graph(%s)" % (node.name, old_graph_dtype[1],raw_graph_dtype[1]))
6158

6259
if raw_graph_dtype[1] == "float16":
6360
if old_graph_dtype[1] == "float64" or old_graph_dtype[1] == "float32":
64-
if re.fullmatch("final_layer_type_\d+/bias", node.name) == None:
61+
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
6562
tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content, dtype=old_graph_dtype[0])
6663
tensor_value = tensor_value.astype(np.float16)
67-
tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
6864
node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, tf.float16, tensor_shape)))
6965

7066
else:
@@ -77,13 +73,12 @@ def transform_graph(raw_graph,old_graph):
7773
node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,tf.float16, [1])))
7874

7975
elif old_graph_dtype[1] == "float16":
80-
tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
8176
tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
8277
node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, tf.float16, tensor_value.shape)))
8378

8479
elif raw_graph_dtype[1] == "float64" or raw_graph_dtype[1] == "float32":
8580
if old_graph_dtype[1] == "float64" or old_graph_dtype[1] == "float32":
86-
if re.fullmatch("final_layer_type_\d+/bias", node.name) == None:
81+
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
8782
tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content,dtype = old_graph_dtype[0])
8883
tensor_value = tensor_value.astype(dtype=raw_graph_dtype[0])
8984
node.attr["value"].tensor.tensor_content = tensor_value.tostring()
@@ -98,13 +93,11 @@ def transform_graph(raw_graph,old_graph):
9893
node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,raw_graph_dtype[0], [1])))
9994

10095
elif old_graph_dtype[1] == "float16":
101-
if re.fullmatch("final_layer_type_\d+/bias", node.name) == None:
102-
tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
96+
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
10397
tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
10498
tensor_value = tensor_value.astype(raw_graph_dtype[0])
10599
node.attr["value"].tensor.tensor_content = tensor_value.tostring()
106100
else:
107-
tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
108101
tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
109102
tensor_value = tensor_value.astype(raw_graph_dtype[0])
110103
node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,raw_graph_dtype[0], tensor_value.shape)))
@@ -127,8 +120,16 @@ def load_transform_node(graph):
127120
layer_\d+_type_\d+/matrix|\
128121
layer_\d+_type_\d+/bias|\
129122
layer_\d+_type_\d+/idt|\
123+
final_layer_type_\d+/matrix|\
124+
descrpt_attr/t_avg|\
125+
descrpt_attr/t_std|\
130126
final_layer_type_\d+/bias|\
131-
final_layer_type_\d+/matrix\
127+
fitting_attr/t_fparam_avg|\
128+
fitting_attr/t_fparam_istd|\
129+
fitting_attr/t_aparam_avg|\
130+
fitting_attr/t_aparam_istd|\
131+
model_attr/t_tab_info|\
132+
model_attr/t_tab_data|\
132133
"
133134
for node in graph.node:
134135
if re.fullmatch(transform_node_pattern,node.name) != None:

0 commit comments

Comments (0)