|
| 1 | +from deepmd.env import tf |
| 2 | +from google.protobuf import text_format |
| 3 | +from tensorflow.python.platform import gfile |
| 4 | +from tensorflow.python import pywrap_tensorflow |
| 5 | +from tensorflow.python.framework import graph_util |
| 6 | + |
| 7 | +def convert_to_13(args): |
| 8 | + convert_pb_to_pbtxt(args.input_model, 'frozen_model.pbtxt') |
| 9 | + convert_to_dp13('frozen_model.pbtxt') |
| 10 | + convert_pbtxt_to_pb('frozen_model.pbtxt', args.output_model) |
| 11 | + print("the converted output model(1.3 support) is saved in %s" % args.output_model) |
| 12 | + |
| 13 | +def convert_pb_to_pbtxt(pbfile, pbtxtfile): |
| 14 | + with gfile.FastGFile(pbfile, 'rb') as f: |
| 15 | + graph_def = tf.GraphDef() |
| 16 | + graph_def.ParseFromString(f.read()) |
| 17 | + tf.import_graph_def(graph_def, name='') |
| 18 | + tf.train.write_graph(graph_def, './', pbtxtfile, as_text=True) |
| 19 | + |
| 20 | +def convert_pbtxt_to_pb(pbtxtfile, pbfile): |
| 21 | + with tf.gfile.FastGFile(pbtxtfile, 'r') as f: |
| 22 | + graph_def = tf.GraphDef() |
| 23 | + file_content = f.read() |
| 24 | + # Merges the human-readable string in `file_content` into `graph_def`. |
| 25 | + text_format.Merge(file_content, graph_def) |
| 26 | + tf.train.write_graph(graph_def, './', pbfile, as_text=False) |
| 27 | + |
| 28 | +def convert_to_dp13(file): |
| 29 | + file_data = "" |
| 30 | + with open(file, "r", encoding="utf-8") as f: |
| 31 | + ii = 0 |
| 32 | + lines = f.readlines() |
| 33 | + while (ii < len(lines)): |
| 34 | + line = lines[ii] |
| 35 | + file_data += line |
| 36 | + ii+=1 |
| 37 | + if 'name' in line and ('DescrptSeA' in line or 'ProdForceSeA' in line or 'ProdVirialSeA' in line): |
| 38 | + while not('attr' in lines[ii] and '{' in lines[ii]): |
| 39 | + file_data += lines[ii] |
| 40 | + ii+=1 |
| 41 | + file_data += ' attr {\n' |
| 42 | + file_data += ' key: \"T\"\n' |
| 43 | + file_data += ' value {\n' |
| 44 | + file_data += ' type: DT_DOUBLE\n' |
| 45 | + file_data += ' }\n' |
| 46 | + file_data += ' }\n' |
| 47 | + with open(file, "w", encoding="utf-8") as f: |
| 48 | + f.write(file_data) |
0 commit comments