-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathremove_caffe_model_bn.py
More file actions
120 lines (95 loc) · 4.9 KB
/
remove_caffe_model_bn.py
File metadata and controls
120 lines (95 loc) · 4.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import caffe
import math
import numpy as np
# Source network: the trained model that still contains BatchNorm/Scale layers.
prototxt = r'model/dst/custom_resnet_v2_160x160_stride16_27.prototxt'
caffemodel = r'model/dst/custom_resnet_v2_160x160_stride16_27.caffemodel'
# Destination network: same architecture with the BatchNorm/Scale layers
# removed from the prototxt; its weights are filled in by this script.
dst_prototxt = r'model/dst/custom_resnet_v2_160x160_stride16_27_without_bn.prototxt'
dst_caffemodel = r'model/dst/custom_resnet_v2_160x160_stride16_27_without_bn.caffemodel'
# Load both nets in TEST phase; net_dst starts with uninitialised params.
net = caffe.Net(prototxt, caffemodel, caffe.TEST)
net_dst = caffe.Net(dst_prototxt, caffe.TEST)
# Seed net_dst with every parameter blob whose layer name also exists in the
# source net (layers present only in the source, e.g. BatchNorm/Scale, are
# skipped because net_dst does not have them).
for layer_name in net_dst.params:
    if layer_name not in net.params:
        continue
    src_blobs = net.params[layer_name]
    dst_blobs = net_dst.params[layer_name]
    for blob_idx, src_blob in enumerate(src_blobs):
        dst_blobs[blob_idx].data[...] = src_blob.data[...]
        print('copy from', layer_name, src_blob.data.shape)
# Eps added to the variance before the square root.  NOTE(review): hard-coded;
# this should be read from batch_norm_param.eps in the prototxt (the original
# Convolution path used 0.001 while the InnerProduct path mistakenly used
# none — unified here).
BN_EPS = 0.001


def _fold_bn_scale(layer_name, bn_name, scale_name):
    """Fold the (BatchNorm, Scale) pair following `layer_name` into its params.

    Reads mean/variance/scale-factor from `net`'s BatchNorm blobs and
    gamma/beta from the Scale blobs, then writes the folded weight (and bias,
    when the destination layer has one) into `net_dst`:

        alpha = gamma / sqrt(var/sf + eps)
        W'    = W * alpha                        (per output channel)
        b'    = b * alpha + beta - (mean/sf) * alpha
    """
    bn_mean = net.params[bn_name][0].data
    bn_variance = net.params[bn_name][1].data
    bn_scale = net.params[bn_name][2].data  # moving-average scale factor
    scale_weight = net.params[scale_name][0].data
    scale_bias = net.params[scale_name][1].data

    weight = net.params[layer_name][0].data
    if len(net.params[layer_name]) > 1:
        bias = net.params[layer_name][1].data
    else:
        bias = 0  # layer has no bias blob; the folded bias is created below

    # Sanity check: an all-zero variance blob usually means the wrong blob
    # order was read or the model was never trained.
    assert np.count_nonzero(bn_variance) == bn_variance.size, \
        'BatchNorm %s has zero variances' % bn_name

    alpha = scale_weight / np.sqrt(bn_variance / bn_scale + BN_EPS)
    print('len(weight)', len(weight), 'len(alpha)', len(alpha))
    assert len(weight) == len(alpha)
    for c in range(len(alpha)):
        # Scale each output channel/row; broadcasting covers both the 4-D
        # convolution weights and the 2-D inner-product weights.
        weight[c] = weight[c] * alpha[c]
    bias = bias * alpha + (scale_bias - (bn_mean / bn_scale) * alpha)

    net_dst.params[layer_name][0].data[...] = weight
    if len(net_dst.params[layer_name]) > 1:
        net_dst.params[layer_name][1].data[...] = bias


# Walk the source network; whenever a Convolution or InnerProduct layer is
# immediately followed by a BatchNorm (and, by assumption, a Scale) layer,
# fold that pair into the layer's own parameters in net_dst.
#
# Fixes over the original script:
#   * `start_remove` was only initialised inside the Convolution branch, so a
#     network whose first foldable layer is an InnerProduct raised NameError;
#     the flag was unconditionally True and has been removed.
#   * The InnerProduct path omitted the eps term the Convolution path used,
#     risking a divide-by-zero and inconsistent folded weights.
for i in range(len(net.layers)):
    layer_type = net.layers[i].type
    if layer_type not in ('Convolution', 'InnerProduct'):
        continue
    layer_name = net._layer_names[i]
    print(layer_name, layer_type)
    j = i + 1
    if j + 1 >= len(net.layers):
        continue  # not enough layers left for a BatchNorm + Scale pair
    print('next type', net.layers[j].type)
    if net.layers[j].type == 'BatchNorm':
        print(' ', net._layer_names[j], net.layers[j].type)
        # NOTE(review): layer j+1 is assumed to be the matching Scale layer,
        # as in the original script — confirm against the prototxt.
        print(' ', net._layer_names[j + 1], net.layers[j + 1].type)
        _fold_bn_scale(layer_name, net._layer_names[j], net._layer_names[j + 1])
# Persist the BN-free weights; they pair with the prototxt at dst_prototxt.
net_dst.save(dst_caffemodel)
print('FINISH ##############################')