|
1 | 1 | from deepmd.env import tf |
| 2 | +import re |
2 | 3 | def transform(args): |
3 | 4 | new_graph = load_graph(args.raw_model) |
4 | 5 | old_graph = load_graph(args.old_model) |
@@ -46,14 +47,17 @@ def check_dim(new_graph_node, old_graph_node, nodeName): |
46 | 47 |
|
47 | 48 | def load_transform_node(graph): |
48 | 49 | transform_node = {} |
49 | | - filter_w = ["filter_type_0/matrix_{}_0".format(i) for i in range(1,10)] |
50 | | - filter_b = ["filter_type_0/bias_{}_0".format(i) for i in range(1,10)] |
51 | | - fitting_w = ["layer_{}_type_0/matrix".format(i) for i in range(0,10)] |
52 | | - fitting_b = ["layer_{}_type_0/bias".format(i) for i in range(0,10)] |
53 | | - fitting_idt = ["layer_{}_type_0/idt".format(i) for i in range(0,10)] |
54 | | - final_layer = ["final_layer_type_0/bias","final_layer_type_0/matrix"] |
55 | | - transform_node_list = filter_w + filter_b + fitting_w + fitting_b + fitting_idt + final_layer |
| 50 | + transform_node_pattern = "\ |
| 51 | +filter_type_\d+/matrix_\d+_\d+|\ |
| 52 | +filter_type_\d+/bias_\d+_\d+|\ |
| 53 | +filter_type_\d+/idt_\d+_\d+\ |
| 54 | +layer_\d+_type_\d+/matrix|\ |
| 55 | +layer_\d+_type_\d+/bias|\ |
| 56 | +layer_\d+_type_\d+/idt|\ |
| 57 | +final_layer_type_\d+/bias|\ |
| 58 | +final_layer_type_\d+/matrix\ |
| 59 | +" |
56 | 60 | for node in graph.node: |
57 | | - if node.name in transform_node_list: |
| 61 | + if re.fullmatch(transform_node_pattern,node.name) != None: |
58 | 62 | transform_node[node.name] = node |
59 | 63 | return transform_node |
0 commit comments