diff --git a/demo.py b/demo.py index bbf8179..8b13fc9 100644 --- a/demo.py +++ b/demo.py @@ -2,85 +2,54 @@ sys.path.append('./') -from yolo.net.yolo_tiny_net import YoloTinyNet -import tensorflow as tf +from yolo.net.yolo_tiny_net import YoloTinyNet +from tools.visualize import PredictionWindow +import tensorflow as tf import cv2 import numpy as np -classes_name = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train","tvmonitor"] +import argparse +# setup CLI argument parser +parser = argparse.ArgumentParser() +parser.add_argument('-s', '--source', + help='either webcam or image', + default='image') -def process_predicts(predicts): - p_classes = predicts[0, :, :, 0:20] - C = predicts[0, :, :, 20:22] - coordinate = predicts[0, :, :, 22:] +args = parser.parse_args() - p_classes = np.reshape(p_classes, (7, 7, 1, 20)) - C = np.reshape(C, (7, 7, 2, 1)) +common_params = {'image_size': 448, 'num_classes': 20, 'batch_size':1} +net_params = {'cell_size': 7, 'boxes_per_cell': 2, 'weight_decay': 0.0005} - P = C * p_classes +window = PredictionWindow(common_params, net_params) +cap = cv2.VideoCapture(0) - #print P[5,1, 0, :] +def get_frame(): + ret, frame = cap.read() - index = np.argmax(P) + height, width = frame.shape[:2] - index = np.unravel_index(index, P.shape) + x = (width - height) / 2 - class_num = index[3] + frame = frame[:, x:width-x, :] + frame = cv2.resize(frame, (448, 448)) - coordinate = np.reshape(coordinate, (7, 7, 2, 4)) + stop = cv2.waitKey(1) & 0xFF == ord('q') - max_coordinate = coordinate[index[0], index[1], index[2], :] + if stop: + cap.release() + return stop, frame - xcenter = max_coordinate[0] - ycenter = max_coordinate[1] - w = max_coordinate[2] - h = max_coordinate[3] +def get_image(): + image = cv2.imread('cat.jpg') + image = cv2.resize(image, (448, 448)) - xcenter = (index[1] + xcenter) * (448/7.0) - ycenter = (index[0] + ycenter) * (448/7.0) - w = w * 448 - h = h * 448 + return True, image - xmin = xcenter - w/2.0 - ymin = ycenter - h/2.0 - - xmax = xmin + w - ymax = ymin + h - - return xmin, ymin, xmax, ymax, class_num - -common_params = {'image_size': 448, 'num_classes': 20, - 'batch_size':1} -net_params = {'cell_size': 7, 'boxes_per_cell':2, 'weight_decay': 0.0005} - -net = YoloTinyNet(common_params, net_params, test=True) - -image = tf.placeholder(tf.float32, (1, 448, 448, 3)) -predicts = net.inference(image) - -sess = tf.Session() - -np_img = cv2.imread('cat.jpg') -resized_img = cv2.resize(np_img, (448, 448)) -np_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB) - - -np_img = np_img.astype(np.float32) - -np_img = np_img / 255.0 * 2 - 1 -np_img = np.reshape(np_img, (1, 448, 448, 3)) - -saver = tf.train.Saver(net.trainable_collection) - -saver.restore(sess, 'models/pretrain/yolo_tiny.ckpt') - -np_predict = sess.run(predicts, feed_dict={image: np_img}) - -xmin, ymin, xmax, ymax, class_num = process_predicts(np_predict) -class_name = classes_name[class_num] -cv2.rectangle(resized_img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255)) -cv2.putText(resized_img, class_name, (int(xmin), int(ymin)), 2, 1.5, (0, 0, 255)) -cv2.imwrite('cat_out.jpg', resized_img) -sess.close() +if args.source == 'image': + window.run(get_image) +elif args.source == 'webcam': + window.run(get_frame) +else: + print('please define a valid source, either "webcam" or "image"') diff --git a/tools/visualize.py b/tools/visualize.py new file mode 100644 index 0000000..83df3f1 --- /dev/null +++ b/tools/visualize.py @@ -0,0 +1,95 @@ +import cv2 +import numpy as np +import tensorflow as tf +from yolo.net.yolo_tiny_net import YoloTinyNet + +CLASS_NAMES = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"] + +class PredictionWindow(object): + """ opens window showing the models predictions """ + def __init__(self, common_params, net_params): + self.net = YoloTinyNet(common_params, net_params, test=True) + + self.image_size = common_params['image_size'] + self.image = tf.placeholder(tf.float32, (1, self.image_size, self.image_size, 3)) + self.predicts = self.net.inference(self.image) + + def run(self, source_callback): + sess = tf.Session() + + saver = tf.train.Saver(self.net.trainable_collection) + saver.restore(sess, 'models/pretrain/yolo_tiny.ckpt') + + stop = False + + while not stop: + stop, frame = source_callback() + + orig = np.copy(frame) + + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = frame.astype(np.float32) + frame = frame / 255.0 * 2 - 1 + frame = np.reshape(frame, (1, self.image_size, self.image_size, 3)) + + predictions = sess.run(self.predicts, feed_dict={self.image: frame}) + boxes = self.process_predicts(predictions) + + for xmin, ymin, xmax, ymax, class_num in boxes: + class_name = CLASS_NAMES[class_num] + cv2.rectangle(orig, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255)) + cv2.putText(orig, class_name, (int(xmin), int(ymin)), 2, 1.5, (0, 0, 255)) + + cv2.imshow('model predictions', orig) + + sess.close() + cv2.waitKey(0) + cv2.destroyAllWindows() + + + def process_predicts(self, predicts): + p_classes = predicts[0, :, :, 0:20] + C = predicts[0, :, :, 20:22] + coordinate = np.reshape(predicts[0, :, :, 22:], (7, 7, 2, 4)) + + p_classes = np.reshape(p_classes, (7, 7, 1, 20)) + C = np.reshape(C, (7, 7, 2, 1)) + + P = C * p_classes + + max_val = np.max(P) + + boxes = [] + for y in range(7): + for x in range(7): + classes = P[y, x] + index = np.argmax(classes) + index = np.unravel_index(index, classes.shape) + + box_index, class_index = index + + #print(box_index, class_index) + + confidence = classes[box_index, class_index] + + if confidence > max_val * 0.8: + class_num = class_index + + cx, cy, w, h = coordinate[y, x, box_index, :] + + cx = (x + cx) * (448 / 7.0) + cy = (y + cy) * (448 / 7.0) + + w = w * 448 + h = h * 448 + + xmin = cx - w / 2.0 + ymin = cy - h / 2.0 + + xmax = xmin + w + ymax = ymin + h + + boxes.append([xmin, ymin, xmax, ymax, class_index]) + + return boxes +