Skip to content

Commit 1084f75

Browse files
authored
Add files via upload
1 parent: 5bbe4e5 · commit: 1084f75

File tree

6 files changed: 45 additions (+45) and 68 deletions (-68).

frcnn.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,11 @@
1919
#--------------------------------------------#
2020
class FRCNN(object):
2121
_defaults = {
22-
"model_path": 'model_data/voc_weights_resnet.pth',
23-
"classes_path": 'model_data/voc_classes.txt',
24-
"confidence": 0.5,
25-
"backbone": "resnet50"
22+
"model_path" : 'model_data/voc_weights_resnet.pth',
23+
"classes_path" : 'model_data/voc_classes.txt',
24+
"confidence" : 0.5,
25+
"iou" : 0.45,
26+
"backbone" : "resnet50"
2627
}
2728

2829
@classmethod
@@ -86,9 +87,11 @@ def detect_image(self, image):
8687
old_height = image_shape[0]
8788
old_image = copy.deepcopy(image)
8889
width,height = get_new_img_size(old_width,old_height)
89-
image = image.resize([width,height])
90+
91+
image = image.resize([width,height], Image.BICUBIC)
9092
photo = np.array(image,dtype = np.float32)/255
9193
photo = np.transpose(photo, (2, 0, 1))
94+
9295
with torch.no_grad():
9396
images = []
9497
images.append(photo)
@@ -97,7 +100,7 @@ def detect_image(self, image):
97100

98101
roi_cls_locs, roi_scores, rois, roi_indices = self.model(images)
99102
decodebox = DecodeBox(self.std, self.mean, self.num_classes)
100-
outputs = decodebox.forward(roi_cls_locs, roi_scores, rois, height=height, width=width, score_thresh = self.confidence)
103+
outputs = decodebox.forward(roi_cls_locs, roi_scores, rois, height = height, width = width, nms_iou = self.iou, score_thresh = self.confidence)
101104
if len(outputs)==0:
102105
return old_image
103106
bbox = outputs[:,:4]
@@ -107,6 +110,7 @@ def detect_image(self, image):
107110
bbox[:, 0::2] = (bbox[:, 0::2])/width*old_width
108111
bbox[:, 1::2] = (bbox[:, 1::2])/height*old_height
109112
bbox = np.array(bbox,np.int32)
113+
110114
image = old_image
111115
thickness = (np.shape(old_image)[0] + np.shape(old_image)[1]) // old_width*2
112116
font = ImageFont.truetype(font='model_data/simhei.ttf',size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))

get_dr_txt.py

Lines changed: 18 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -3,32 +3,37 @@
33
# 具体视频教程可查看
44
# https://www.bilibili.com/video/BV1zE411u7Vw
55
#----------------------------------------------------#
6-
from frcnn import FRCNN
7-
from PIL import Image
8-
from torch.autograd import Variable
9-
import torch
10-
import numpy as np
6+
import copy
117
import os
8+
9+
import numpy as np
10+
import torch
1211
import torch.backends.cudnn as cudnn
12+
from PIL import Image, ImageDraw, ImageFont
13+
from torch.autograd import Variable
1314
from torch.nn import functional as F
14-
from utils.utils import loc2bbox, nms, DecodeBox
15+
from tqdm import tqdm
16+
17+
from frcnn import FRCNN
1518
from nets.frcnn import FasterRCNN
1619
from nets.frcnn_training import get_new_img_size
17-
from PIL import Image, ImageFont, ImageDraw
18-
import copy
20+
from utils.utils import DecodeBox, loc2bbox, nms
21+
1922

2023
class mAP_FRCNN(FRCNN):
2124
#---------------------------------------------------#
2225
# 检测图片
2326
#---------------------------------------------------#
2427
def detect_image(self,image_id,image):
25-
self.confidence = 0.05
28+
self.confidence = 0.01
29+
self.iou = 0.45
2630
f = open("./input/detection-results/"+image_id+".txt","w")
2731
image_shape = np.array(np.shape(image)[0:2])
2832
old_width = image_shape[1]
2933
old_height = image_shape[0]
3034
width,height = get_new_img_size(old_width,old_height)
31-
image = image.resize([width,height])
35+
36+
image = image.resize([width,height], Image.BICUBIC)
3237
photo = np.array(image,dtype = np.float32)/255
3338
photo = np.transpose(photo, (2, 0, 1))
3439
with torch.no_grad():
@@ -39,7 +44,7 @@ def detect_image(self,image_id,image):
3944

4045
roi_cls_locs, roi_scores, rois, roi_indices = self.model(images)
4146
decodebox = DecodeBox(self.std, self.mean, self.num_classes)
42-
outputs = decodebox.forward(roi_cls_locs, roi_scores, rois, height=height, width=width, score_thresh = self.confidence)
47+
outputs = decodebox.forward(roi_cls_locs, roi_scores, rois, height = height, width = width, nms_iou = self.iou, score_thresh = self.confidence)
4348
if len(outputs)==0:
4449
return
4550
bbox = outputs[:,:4]
@@ -71,12 +76,11 @@ def detect_image(self,image_id,image):
7176
os.makedirs("./input/images-optional")
7277

7378

74-
for image_id in image_ids:
79+
for image_id in tqdm(image_ids):
7580
image_path = "./VOCdevkit/VOC2007/JPEGImages/"+image_id+".jpg"
7681
image = Image.open(image_path)
77-
image.save("./input/images-optional/"+image_id+".jpg")
82+
# image.save("./input/images-optional/"+image_id+".jpg")
7883
frcnn.detect_image(image_id,image)
79-
print(image_id," done!")
8084

8185

8286
print("Conversion completed!")

nets/frcnn.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,4 +63,9 @@ def forward(self, x, scale=1.):
6363
# print(np.shape(rois))
6464
# print(roi_indices)
6565
roi_cls_locs, roi_scores = self.head.forward(h, rois, roi_indices)
66-
return roi_cls_locs, roi_scores, rois, roi_indices
66+
return roi_cls_locs, roi_scores, rois, roi_indices
67+
68+
def freeze_bn(self):
69+
for m in self.modules():
70+
if isinstance(m, nn.BatchNorm2d):
71+
m.eval()

nets/rpn.py

Lines changed: 1 addition & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,21 +8,12 @@
88
import numpy as np
99

1010

11-
'''
12-
一些建议的参数设置:
13-
VGG:SGD优化器,冻结时学习率1e-3,解冻时学习率1e-4
14-
nets.rpn中ProposalCreator的n_train_post_nms=2000;
15-
utils.utils中ProposalTargetCreator的pos_ratio=0.25;
16-
RESNET50:Adam优化器,冻结时学习率1e-4,解冻时学习率1e-5
17-
nets.rpn中ProposalCreator的n_train_post_nms=300;
18-
utils.utils中ProposalTargetCreator的pos_ratio=0.5;
19-
'''
2011
class ProposalCreator():
2112
def __init__(self,
2213
mode,
2314
nms_thresh=0.7,
2415
n_train_pre_nms=12000,
25-
n_train_post_nms=300,
16+
n_train_post_nms=2000,
2617
n_test_pre_nms=3000,
2718
n_test_post_nms=300,
2819
min_size=16

train.py

Lines changed: 6 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -75,15 +75,6 @@ def fit_ont_epoch(net,epoch,epoch_size,epoch_size_val,gen,genval,Epoch):
7575
print('Saving state, iter:', str(epoch+1))
7676
torch.save(model.state_dict(), 'logs/Epoch%d-Total_Loss%.4f-Val_Loss%.4f.pth'%((epoch+1),total_loss/(epoch_size+1),val_toal_loss/(epoch_size_val+1)))
7777

78-
'''
79-
一些建议的参数设置:
80-
VGG:SGD优化器,冻结时学习率1e-3,解冻时学习率1e-4
81-
nets.rpn中ProposalCreator的n_train_post_nms=2000;
82-
utils.utils中ProposalTargetCreator的pos_ratio=0.25;
83-
RESNET50:Adam优化器,冻结时学习率1e-4,解冻时学习率1e-5
84-
nets.rpn中ProposalCreator的n_train_post_nms=300;
85-
utils.utils中ProposalTargetCreator的pos_ratio=0.5;
86-
'''
8778
if __name__ == "__main__":
8879
# 参数初始化
8980
annotation_path = '2007_train.txt'
@@ -118,22 +109,12 @@ def fit_ont_epoch(net,epoch,epoch_size,epoch_size_val,gen,genval,Epoch):
118109
num_val = int(len(lines)*val_split)
119110
num_train = len(lines) - num_val
120111

121-
'''
122-
一些建议的参数设置:
123-
VGG:SGD优化器,冻结时学习率1e-3,解冻时学习率1e-4
124-
nets.rpn中ProposalCreator的n_train_post_nms=2000;
125-
utils.utils中ProposalTargetCreator的pos_ratio=0.25;
126-
RESNET50:Adam优化器,冻结时学习率1e-4,解冻时学习率1e-5
127-
nets.rpn中ProposalCreator的n_train_post_nms=300;
128-
utils.utils中ProposalTargetCreator的pos_ratio=0.5;
129-
'''
130112
if True:
131113
lr = 1e-4
132114
Init_Epoch = 0
133-
Freeze_Epoch = 25
115+
Freeze_Epoch = 50
134116

135117
optimizer = optim.Adam(model.parameters(),lr,weight_decay=5e-4)
136-
# optimizer = optim.SGD(model.parameters(),lr,weight_decay=5e-4,momentum=0.9)
137118
lr_scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=1,gamma=0.95)
138119

139120
if Use_Data_Loader:
@@ -158,18 +139,18 @@ def fit_ont_epoch(net,epoch,epoch_size,epoch_size_val,gen,genval,Epoch):
158139
# ------------------------------------#
159140
# 由于batch==1所以冻结bn层
160141
# ------------------------------------#
161-
model = model.eval()
142+
model.freeze_bn()
162143

163144
for epoch in range(Init_Epoch,Freeze_Epoch):
164145
fit_ont_epoch(model,epoch,epoch_size,epoch_size_val,gen,gen_val,Freeze_Epoch)
165146
lr_scheduler.step()
166147

167148
if True:
168149
lr = 1e-5
169-
Freeze_Epoch = 25
170-
Unfreeze_Epoch = 50
150+
Freeze_Epoch = 50
151+
Unfreeze_Epoch = 100
152+
171153
optimizer = optim.Adam(model.parameters(),lr,weight_decay=5e-4)
172-
# optimizer = optim.SGD(model.parameters(),lr,weight_decay=5e-4,momentum=0.9)
173154
lr_scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=1,gamma=0.95)
174155

175156
if Use_Data_Loader:
@@ -194,7 +175,7 @@ def fit_ont_epoch(net,epoch,epoch_size,epoch_size_val,gen,genval,Epoch):
194175
# ------------------------------------#
195176
# 由于batch==1所以冻结bn层
196177
# ------------------------------------#
197-
model = model.eval()
178+
model.freeze_bn()
198179

199180
for epoch in range(Freeze_Epoch,Unfreeze_Epoch):
200181
fit_ont_epoch(model,epoch,epoch_size,epoch_size_val,gen,gen_val,Unfreeze_Epoch)

utils/utils.py

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,7 @@ def __init__(self, std, mean, num_classes):
8787
self.mean = mean
8888
self.num_classes = num_classes + 1
8989

90-
def forward(self, roi_cls_locs, roi_scores, rois, height, width, score_thresh):
90+
def forward(self, roi_cls_locs, roi_scores, rois, height, width, nms_iou, score_thresh):
9191

9292
rois = torch.Tensor(rois)
9393

@@ -130,24 +130,16 @@ def forward(self, roi_cls_locs, roi_scores, rois, height, width, score_thresh):
130130

131131
prob_l_index = np.argsort(prob_l)[::-1]
132132
detections_class = detections_class[prob_l_index]
133-
nms_out = nms(detections_class, 0.3)
133+
nms_out = nms(detections_class, nms_iou)
134134
if outputs==[]:
135135
outputs = nms_out
136136
else:
137137
outputs = np.concatenate([outputs, nms_out],axis=0)
138138
return outputs
139-
'''
140-
一些建议的参数设置:
141-
VGG:SGD优化器,冻结时学习率1e-3,解冻时学习率1e-4
142-
nets.rpn中ProposalCreator的n_train_post_nms=2000;
143-
utils.utils中ProposalTargetCreator的pos_ratio=0.25;
144-
RESNET50:Adam优化器,冻结时学习率1e-4,解冻时学习率1e-5
145-
nets.rpn中ProposalCreator的n_train_post_nms=300;
146-
utils.utils中ProposalTargetCreator的pos_ratio=0.5;
147-
'''
139+
148140
class ProposalTargetCreator(object):
149141
def __init__(self,n_sample=128,
150-
pos_ratio=0.5, pos_iou_thresh=0.5,
142+
pos_ratio=0.25, pos_iou_thresh=0.5,
151143
neg_iou_thresh_hi=0.5, neg_iou_thresh_lo=0
152144
):
153145
self.n_sample = n_sample

0 commit comments

Comments
 (0)