Skip to content

Commit db2ffde

Browse files
authored
Add files via upload
1 parent 30b9b4c commit db2ffde

File tree

1 file changed

+249
-0
lines changed

1 file changed

+249
-0
lines changed

tools/demo2.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
#!/usr/bin/env python
2+
# -*- coding: UTF-8 -*
3+
4+
from __future__ import absolute_import
5+
from __future__ import division
6+
from __future__ import print_function
7+
8+
import _init_paths
9+
import matplotlib
10+
matplotlib.use('Agg')
11+
from model.config import cfg
12+
from model.test import im_detect
13+
from model.nms_wrapper import nms
14+
15+
from utils.timer import Timer
16+
import tensorflow as tf
17+
from matplotlib.font_manager import FontProperties
18+
zhfont1 = matplotlib.font_manager.FontProperties(fname='/usr/share/fonts/opentype/noto/NotoSansCJK-Bold.ttc')
19+
import matplotlib.pyplot as plt
20+
import numpy as np
21+
import os, cv2
22+
import argparse
23+
24+
import csv
25+
import time
26+
27+
from nets.vgg16 import vgg16
28+
from nets.resnet_v1_rfcn_hole import resnetv1
29+
import sys
30+
reload(sys)
31+
sys.setdefaultencoding('utf8')
32+
33+
34+
CLASSES = ('__background__',
35+
'dr0', 'dr1', 'dr2', 'dr3')
36+
37+
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),'res101': ('res101_faster_rcnn_iter_200000.ckpt',)}
38+
DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
39+
localtime = time.asctime( time.localtime(time.time()) )
40+
41+
#thresh在demo()设置了CONF_THRESH=0.8只有概率大于0.8才会显示
42+
#inds被保留的区块序号,区块信息在dets里
43+
def vis_detections(im, class_name, dets, thresh=0.5 ,image_name='null'):
44+
"""Draw detected bounding boxes."""
45+
inds = np.where(dets[:, -1] >= thresh)[0]
46+
if len(inds) == 0:
47+
tm=[0,0,2,2]
48+
return 0,0,tm,thresh
49+
#print(inds)
50+
#im = im[:, :, (2, 1, 0)]
51+
#fig, ax = plt.subplots(figsize=(12, 12))
52+
#ax.imshow(im, aspect='equal')
53+
temp=0
54+
maxscore = max(dets[:, -1])
55+
for i in inds:
56+
bbox = dets[i, :4]
57+
score = dets[i, -1]
58+
if score==maxscore:
59+
return class_name,score,bbox,thresh
60+
#ax.add_patch(plt.Rectangle((bbox[0], bbox[1]),bbox[2] - bbox[0],bbox[3] - bbox[1], fill=False,edgecolor='red', linewidth=3.5))
61+
#ax.text(bbox[0], bbox[1] - 2,'{:s} {:.3f}'.format(class_name, score),bbox=dict(facecolor='blue', alpha=0.5),fontsize=14, color='white')
62+
63+
#ax.set_title(('{} detection results p({} | box) >= {:.1f}').format(class_name, class_name,thresh),fontsize=14)
64+
#plt.axis('off')
65+
#plt.tight_layout()
66+
#plt.draw()
67+
#plt.savefig("/var/www/html/figure/"+image_name)
68+
69+
def vis_detections_onlyone(im, class_name, dets, thresh=0.5):
70+
"""Draw detected bounding boxes."""
71+
inds = np.where(dets[:, -1] >= thresh)[0]
72+
if len(inds) == 0:
73+
return
74+
im = im[:, :, (2, 1, 0)]
75+
fig, ax = plt.subplots(figsize=(12, 12))
76+
ax.imshow(im, aspect='equal')
77+
maxscore = max(dets[:, -1])
78+
for i in inds:
79+
bbox = dets[i, :4]
80+
score = dets[i, -1]
81+
if score==maxscore:
82+
ax.add_patch(plt.Rectangle((bbox[0], bbox[1]),bbox[2] - bbox[0],bbox[3] - bbox[1], fill=False,edgecolor='red', linewidth=3.5))
83+
ax.text(bbox[0], bbox[1] - 2,'{:s} {:.3f}'.format(class_name, score),bbox=dict(facecolor='blue', alpha=0.5),fontsize=14, color='white')
84+
85+
ax.set_title(('{} detection results '
86+
'p({} | box) >= {:.1f}').format(class_name, class_name,
87+
thresh),
88+
fontsize=14)
89+
plt.axis('off')
90+
plt.tight_layout()
91+
plt.draw()
92+
93+
def demo(sess, net, image_name):
94+
"""Detect object classes in an image using pre-computed object proposals."""
95+
96+
# Load the demo image
97+
im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
98+
im = cv2.imread(im_file)
99+
100+
# Detect all object classes and regress object bounds
101+
# im_dect at test.py
102+
timer = Timer()
103+
timer.tic()
104+
scores, boxes = im_detect(sess, net, im)
105+
#print(scores)
106+
#print(boxes)
107+
timer.toc()
108+
print('检测区域采样时间 {:.3f}s 共计 {:d} 个目标区块'.format(timer.total_time, boxes.shape[0]))
109+
110+
# Visualize detections for each class
111+
CONF_THRESH = 0.5
112+
NMS_THRESH = 0.3
113+
#enumerate枚举 hstack矩阵拼接 keep记录nms筛选后的区块 dets保存的每个区块的(x1 y1 x2 y2 score)格式list
114+
b=0
115+
e='error'
116+
f=[0,0,2,2]
117+
for cls_ind, cls in enumerate(CLASSES[1:]):
118+
cls_ind += 1 # because we skipped background
119+
cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
120+
cls_scores = scores[:, cls_ind]
121+
dets = np.hstack((cls_boxes,
122+
cls_scores[:, np.newaxis])).astype(np.float32)
123+
keep = nms(dets, NMS_THRESH)
124+
#print(keep)
125+
dets = dets[keep, :]
126+
a=image_name
127+
c,d,g,t=vis_detections(im, cls, dets, thresh=CONF_THRESH, image_name=a)
128+
if (c==0 and d==0):
129+
1+1
130+
else:
131+
if (d>b):
132+
b=d
133+
e=c
134+
f=g
135+
if (e=='dr0'):
136+
e='正常人dr0'
137+
if (e=='dr1'):
138+
e='轻度患者dr1'
139+
if (e=='dr2'):
140+
e='中度患者dr2'
141+
if (e=='dr3'):
142+
e='重度患者dr3'
143+
if (e=='dr4'):
144+
e='增殖患者dr4'
145+
print(e,b)
146+
im = im[:, :, (2, 1, 0)]
147+
fig, ax = plt.subplots(figsize=(8, 6))
148+
ax.imshow(im, aspect='equal')
149+
ax.add_patch(plt.Rectangle((f[0], f[1]),f[2] - f[0],f[3] - f[1], fill=False,edgecolor='red', linewidth=3.5))
150+
ax.text(f[0], f[1] + 25,'{:s} {:.3f}'.format(e, b),bbox=dict(facecolor='blue', alpha=0.5),fontsize=14, color='white',fontproperties=zhfont1)
151+
152+
ax.set_title(('辅助诊断结果:{} p({} | box) >= {:.1f}').format(e, e,t),fontsize=14,fontproperties=zhfont1)
153+
plt.axis('off')
154+
plt.tight_layout()
155+
plt.draw()
156+
#plt.savefig("/var/www/html/figure/"+image_name)
157+
plt.savefig("figure/"+localtime+image_name)
158+
filer=open('results/result'+localtime,'a+')
159+
filer.write(image_name+' '+e+' '+str(b)+'\n')
160+
filer.close
161+
162+
def parse_args():
163+
"""Parse input arguments."""
164+
parser = argparse.ArgumentParser(description='Tensorflow demo')
165+
parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
166+
choices=NETS.keys(), default='res101')
167+
parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
168+
choices=DATASETS.keys(), default='pascal_voc')
169+
args = parser.parse_args()
170+
171+
return args
172+
173+
def csv_writer(data, filename):
174+
with open(filename, "wb") as csv_file:
175+
writer = csv.writer(csv_file)
176+
for line in data:
177+
writer.writerow(line)
178+
179+
if __name__ == '__main__':
180+
cfg.TEST.HAS_RPN = True # Use RPN for proposals
181+
args = parse_args()
182+
183+
# model path
184+
demonet = args.demo_net
185+
dataset = args.dataset
186+
tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
187+
NETS[demonet][0])
188+
189+
190+
if not os.path.isfile(tfmodel + '.meta'):
191+
raise IOError(('{:s} not found.\nDid you download the proper networks from '
192+
'our server and place them properly?').format(tfmodel + '.meta'))
193+
194+
# set config
195+
tfconfig = tf.ConfigProto(allow_soft_placement=True)
196+
tfconfig.gpu_options.allow_growth=True
197+
print ("\033[1;34mThe code is licensed by engineer1109.")
198+
#print ("\033[1;34m开始初始化系统")
199+
# init session
200+
sess = tf.Session(config=tfconfig)
201+
#print ("\033[1;33m卷积模型开始加载,默认是RES101")
202+
# load network
203+
if demonet == 'vgg16':
204+
net = vgg16(batch_size=1)
205+
elif demonet == 'res101':
206+
net = resnetv1(batch_size=1, num_layers=101)
207+
else:
208+
raise NotImplementedError
209+
#print(demonet)
210+
net.create_architecture(sess, "TEST", 5,
211+
tag='default', anchor_scales=[8, 16, 32])
212+
saver = tf.train.Saver()
213+
saver.restore(sess, tfmodel)
214+
#print(saver.restore(sess, tfmodel))
215+
216+
#print('网络加载完毕 {:s}'.format(tfmodel))
217+
result=[]
218+
fd = file("images.txt", "r" )
219+
for line in fd.readlines():
220+
result.append(list(map(str,line.split(','))))
221+
#print (result)
222+
print ("欢迎使用TF-RFCN测试模式</br>")
223+
print ("总共需要辨识的图片数量</br>")
224+
size=len(result)
225+
print (size)
226+
image_name = [1]*size
227+
for i in range(size):
228+
var=str(result[i][0])
229+
var=var.strip()
230+
image_name[i]=var
231+
#print (image_name)
232+
#print(type(result))
233+
for image_name in image_name:
234+
print('</br>=======================</br>')
235+
print('====测试-TF-RFCN====</br>')
236+
print('测试数据 data/demo/{}</br>'.format(image_name))
237+
demo(sess, net, image_name)
238+
data = []
239+
data2=[]
240+
with open('results/result'+localtime) as f:
241+
for line in f:
242+
data2.append(line.strip().split(" "))
243+
#print('</br>概率分布</br>')
244+
#print (data2)
245+
filename = "csvfiles/output"+localtime+".csv"
246+
csv_writer(data2, filename)
247+
filename = "output.csv"
248+
csv_writer(data2, filename)
249+
#plt.show()

0 commit comments

Comments
 (0)