forked from nyukat/breast_density_classifier
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconvert_model.py
More file actions
67 lines (59 loc) · 2.85 KB
/
convert_model.py
File metadata and controls
67 lines (59 loc) · 2.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import argparse
import torch
import tensorflow as tf
import models_torch
def histogram_tf_to_torch(input_path, output_path):
    """Convert the TensorFlow histogram-baseline checkpoint to a PyTorch state dict.

    Restores the checkpoint at ``input_path`` (its meta graph is read from
    ``input_path + ".meta"``), copies the two fully connected layers into a
    ``models_torch.BaselineHistogramModel(num_bins=50)``, and writes that
    model's ``state_dict()`` to ``output_path`` with ``torch.save``.

    Args:
        input_path: TensorFlow checkpoint prefix (as accepted by ``Saver.restore``).
        output_path: Destination file for the PyTorch state dict.

    Raises:
        ValueError: If a checkpoint tensor's shape does not match the shape of
            the PyTorch parameter it is copied into.
    """

    def _copy_into(param, value, tensor_name):
        # Copy in place instead of rebinding ``param.data``: a checkpoint whose
        # shapes do not match the PyTorch model fails here, rather than silently
        # producing a state dict that cannot be loaded later.
        value = torch.as_tensor(value, dtype=param.dtype)
        if value.shape != param.shape:
            raise ValueError(
                "{}: checkpoint shape {} does not match parameter shape {}".format(
                    tensor_name, tuple(value.shape), tuple(param.shape)))
        with torch.no_grad():
            param.copy_(value)

    g = tf.Graph()
    histogram_model = models_torch.BaselineHistogramModel(num_bins=50)
    # Each row is (PyTorch parameter, TF tensor name, whether to transpose the
    # TF value). Transposing maps the TF (in, out) weight layout onto
    # nn.Linear.weight's (out, in) layout.
    params = [
        (histogram_model.fc1.weight, "fully_connected/weights:0", True),
        (histogram_model.fc1.bias, "fully_connected/biases:0", False),
        (histogram_model.fc2.weight, "fully_connected_2/weights:0", True),
        (histogram_model.fc2.bias, "fully_connected_2/biases:0", False),
    ]
    # Soft placement (also used by the CNN converter) lets ops pinned to an
    # unavailable device, e.g. a GPU, fall back to one that exists.
    config = tf.ConfigProto(allow_soft_placement=True)
    # import_meta_graph imports into the default graph; make ``g`` the default
    # explicitly so the lookups below find the restored tensors.
    with g.as_default(), tf.Session(graph=g, config=config) as sess:
        saver = tf.train.import_meta_graph(input_path + ".meta")
        saver.restore(sess, input_path)
        for param, tensor_name, transpose in params:
            value = sess.run(g.get_tensor_by_name(tensor_name))
            _copy_into(param, value.T if transpose else value, tensor_name)
    torch.save(histogram_model.state_dict(), output_path)
def cnn_tf_to_torch(input_path, output_path):
    """Convert the TensorFlow CNN-baseline checkpoint to a PyTorch state dict.

    Restores the checkpoint at ``input_path`` (its meta graph is read from
    ``input_path + ".meta"``) and copies every convolution layer (separately
    for the "CC" and "MLO" views) plus the two fully connected layers into a
    CPU ``models_torch.BaselineBreastModel``. The model's ``state_dict()`` is
    then written to ``output_path`` with ``torch.save``.

    Args:
        input_path: TensorFlow checkpoint prefix (as accepted by ``Saver.restore``).
        output_path: Destination file for the PyTorch state dict.

    Raises:
        KeyError: If an expected variable is not among the checkpoint's
            trainable variables.
        ValueError: If a checkpoint tensor's shape does not match the shape of
            the PyTorch parameter it is copied into.
    """

    def _copy_into(param, value, tensor_name):
        # Copy in place instead of rebinding ``param.data``. This validates
        # shapes, and it stores permuted (non-contiguous) conv kernels in the
        # parameter's own contiguous layout.
        value = torch.as_tensor(value, dtype=param.dtype)
        if value.shape != param.shape:
            raise ValueError(
                "{}: checkpoint shape {} does not match parameter shape {}".format(
                    tensor_name, tuple(value.shape), tuple(param.shape)))
        with torch.no_grad():
            param.copy_(value)

    g = tf.Graph()
    device = torch.device("cpu")
    bbmodel = models_torch.BaselineBreastModel(device, nodropout_probability=1.0)
    config = tf.ConfigProto(allow_soft_placement=True)
    # import_meta_graph imports into the default graph; make ``g`` the default
    # explicitly so the variable collection below is read from the restored graph.
    with g.as_default(), tf.Session(graph=g, config=config) as sess:
        saver = tf.train.import_meta_graph(input_path + ".meta")
        saver.restore(sess, input_path)
        var_dict = {
            var.name: var
            for var in g.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        }

        def fetch(name):
            return sess.run(var_dict[name])

        for conv_name, conv_layer in bbmodel.conv_layer_dict.items():
            for view in ["CC", "MLO"]:
                prefix = "{}_{}".format(conv_name, view)
                weights_name = prefix + "/weights:0"
                biases_name = prefix + "/biases:0"
                # Reorder TF conv kernels (H, W, in, out) into the torch
                # Conv2d weight layout (out, in, H, W).
                kernel = torch.as_tensor(fetch(weights_name)).permute(3, 2, 0, 1)
                _copy_into(conv_layer.ops[view].weight, kernel, weights_name)
                _copy_into(conv_layer.ops[view].bias, fetch(biases_name), biases_name)

        # Transposing maps the TF (in, out) weight layout onto
        # nn.Linear.weight's (out, in) layout.
        for layer, tf_prefix in [(bbmodel.fc1, "fully_connected"),
                                 (bbmodel.fc2, "fully_connected_2")]:
            weights_name = tf_prefix + "/weights:0"
            biases_name = tf_prefix + "/biases:0"
            _copy_into(layer.weight, fetch(weights_name).T, weights_name)
            _copy_into(layer.bias, fetch(biases_name), biases_name)
    torch.save(bbmodel.state_dict(), output_path)
if __name__ == "__main__":
    # Command-line entry point: pick a converter by model type and run it.
    parser = argparse.ArgumentParser(description='Convert from TensorFlow checkpoints to PyTorch pickles')
    # ``choices`` makes argparse reject an unknown model type with a usage
    # message, instead of the script failing later with a bare RuntimeError.
    parser.add_argument('model_type', choices=["histogram", "cnn"],
                        help='which baseline model the checkpoint contains')
    parser.add_argument('input_path',
                        help='TensorFlow checkpoint prefix (the ".meta" file must sit next to it)')
    parser.add_argument('output_path',
                        help='destination file for the PyTorch state_dict')
    args = parser.parse_args()
    converters = {
        "histogram": histogram_tf_to_torch,
        "cnn": cnn_tf_to_torch,
    }
    converters[args.model_type](args.input_path, args.output_path)