Skip to content

Commit 8385272

Browse files
authored
Update transform.py
1 parent df5f3a6 commit 8385272

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

source/train/transform.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from deepmd.env import tf
2+
import re
23
def transform(args):
34
new_graph = load_graph(args.raw_model)
45
old_graph = load_graph(args.old_model)
@@ -46,14 +47,17 @@ def check_dim(new_graph_node, old_graph_node, nodeName):
4647

4748
def load_transform_node(graph):
4849
transform_node = {}
49-
filter_w = ["filter_type_0/matrix_{}_0".format(i) for i in range(1,10)]
50-
filter_b = ["filter_type_0/bias_{}_0".format(i) for i in range(1,10)]
51-
fitting_w = ["layer_{}_type_0/matrix".format(i) for i in range(0,10)]
52-
fitting_b = ["layer_{}_type_0/bias".format(i) for i in range(0,10)]
53-
fitting_idt = ["layer_{}_type_0/idt".format(i) for i in range(0,10)]
54-
final_layer = ["final_layer_type_0/bias","final_layer_type_0/matrix"]
55-
transform_node_list = filter_w + filter_b + fitting_w + fitting_b + fitting_idt + final_layer
50+
transform_node_pattern = "\
51+
filter_type_\d+/matrix_\d+_\d+|\
52+
filter_type_\d+/bias_\d+_\d+|\
53+
filter_type_\d+/idt_\d+_\d+\
54+
layer_\d+_type_\d+/matrix|\
55+
layer_\d+_type_\d+/bias|\
56+
layer_\d+_type_\d+/idt|\
57+
final_layer_type_\d+/bias|\
58+
final_layer_type_\d+/matrix\
59+
"
5660
for node in graph.node:
57-
if node.name in transform_node_list:
61+
if re.fullmatch(transform_node_pattern,node.name) != None:
5862
transform_node[node.name] = node
5963
return transform_node

0 commit comments

Comments
 (0)