-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathremove_caffe_prototxt_bn.py
More file actions
52 lines (43 loc) · 1.97 KB
/
remove_caffe_prototxt_bn.py
File metadata and controls
52 lines (43 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import caffe
import math
import numpy as np
import json
from caffe.proto import caffe_pb2
from google.protobuf import text_format
# Both paths share one model stem; only the suffix differs.
_model_stem = r'model/dst/custom_resnet_v2_160x160_stride16_27'
# Source network definition (text-format Caffe prototxt) to read.
prototxt = _model_stem + '.prototxt'
# Destination for the rewritten network with the BatchNorm/Scale layers removed.
dst_prototxt = _model_stem + '_without_bn.prototxt'
def readProtoFile(filepath, parser_object):
    """Parse a text-format protobuf file into ``parser_object``.

    Args:
        filepath: Path to the text-format (e.g. ``.prototxt``) file.
        parser_object: Protobuf message instance that the parsed fields are
            merged into; it is modified in place.

    Returns:
        The same ``parser_object``, so callers can construct and fill a
        message in one expression.
    """
    # ``with`` closes the handle even when text_format.Merge raises on a
    # malformed file (the original leaked it in that case).
    with open(filepath, "r") as proto_file:
        text_format.Merge(proto_file.read(), parser_object)
    return parser_object
def readProtoSolverFile(filepath):
    """Load a Caffe network definition from a text-format prototxt file.

    Note: despite the "Solver" in its name, this parses the file as a
    ``NetParameter`` (a network/model definition), not a solver config.

    Args:
        filepath: Path to the text-format prototxt file.

    Returns:
        A freshly created ``caffe_pb2.NetParameter`` populated from the file.
    """
    return readProtoFile(filepath, caffe.proto.caffe_pb2.NetParameter())
def _followed_by_bn_scale(layers, i):
    """Return True if layers[i] is a Convolution/InnerProduct layer that is
    immediately followed by a BatchNorm layer and then a Scale layer."""
    return (
        layers[i].type in ('Convolution', 'InnerProduct')
        and i + 2 < len(layers)
        and layers[i + 1].type == 'BatchNorm'
        and layers[i + 2].type == 'Scale'
    )


# Rewrite the network without the BatchNorm + Scale pairs that directly follow
# Convolution / InnerProduct layers.  The preceding layer takes over the Scale
# layer's output blob and gets bias_term enabled (presumably so it can hold
# BN/Scale parameters folded in by a separate weight-conversion step --
# confirm against that script).
net_params = readProtoSolverFile(prototxt)
print(net_params.name)

# Copy every net-level field except the layer list so that nothing besides
# the layers is lost (the original wrote only `name`, dropping any other
# top-level NetParameter fields such as input declarations).
net_header = type(net_params)()
net_header.CopyFrom(net_params)
net_header.ClearField('layer')

layers = net_params.layer
with open(dst_prototxt, 'w') as outfile:
    outfile.write(text_format.MessageToString(net_header))
    outfile.write('\n')

    # Number of upcoming layers (the BatchNorm and Scale of a matched triple)
    # to omit from the output.
    layers_to_drop = 0
    for i, layer in enumerate(layers):
        print(layer.name)
        if layers_to_drop:
            layers_to_drop -= 1
            continue
        # Only an exact Conv/IP -> BatchNorm -> Scale triple is folded; this
        # avoids an IndexError on a trailing BatchNorm and never rewires the
        # top to, or deletes, an unrelated layer.
        if _followed_by_bn_scale(layers, i):
            # Downstream layers consume the Scale layer's top, so the fused
            # layer must now produce that blob.
            layer.top[0] = layers[i + 2].top[0]
            if layer.type == 'Convolution':
                layer.convolution_param.bias_term = True
            else:
                layer.inner_product_param.bias_term = True
            layers_to_drop = 2
        outfile.write('layer {\n')
        # Indent every line of the layer body by one space inside the block.
        outfile.write(' '.join(('\n' + str(layer)).splitlines(True)))
        outfile.write('\n}\n\n')