import torch
import torch.nn.functional as F

import cv2
import numpy as np
from numba import jit
from numba.typed import List

from .bbox import *
1811
def detect(net, img, device):
    """Run the S3FD face detector on a single HWC image.

    Args:
        net: S3FD network (a callable torch module).
        img: numpy array of shape (H, W, 3); RGB channel order
            (batch_detect flips channels to BGR for the network).
        device: torch device string, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        The result of ``batch_detect`` for a batch of one image:
        candidate ``[x1, y1, x2, y2, score]`` detections.
    """
    # HWC -> CHW, the layout the network expects.
    img = img.transpose(2, 0, 1)
    # Creates a batch of 1 (adds a leading axis without copying).
    img = np.expand_dims(img, 0)

    # Convert to float32 on the target device up front; feeding uint8
    # through the network previously hit an overflow bug downstream.
    img = torch.from_numpy(img).to(device, dtype=torch.float32)

    return batch_detect(net, img, device)
2720
def batch_detect(net, img_batch, device):
    """Run the S3FD face detector on a batch of images.

    Args:
        net: S3FD network (a callable torch module) returning a list of
            feature maps, alternating classification / regression heads.
        img_batch: tensor of shape (B, 3, H, W), RGB channel order;
            converted to float32 on `device` here.
        device: torch device string, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        Per-image lists of raw ``[x1, y1, x2, y2, score]`` candidates
        from ``get_predictions`` (class score > 0.05).
    """
    if 'cuda' in device:
        # Fixed-size inputs benefit from cuDNN autotuning.
        torch.backends.cudnn.benchmark = True

    batch_size = img_batch.size(0)
    img_batch = img_batch.to(device, dtype=torch.float32)

    img_batch = img_batch.flip(-3)  # RGB to BGR
    # Subtract the per-channel mean the network was trained with.
    img_batch = img_batch - torch.tensor([104.0, 117.0, 123.0], device=device).view(1, 3, 1, 1)

    with torch.no_grad():
        olist = net(img_batch)  # patched uint8_t overflow error

    # Even indices hold class logits; turn them into probabilities.
    for i in range(len(olist) // 2):
        olist[i * 2] = F.softmax(olist[i * 2], dim=1)

    # Move everything to numpy once, so the numba-jitted decoder can run.
    olist = [oelem.data.cpu().numpy() for oelem in olist]

    # numba.typed.List gives the jitted function a typed container
    # (plain Python lists are deprecated in nopython mode).
    return get_predictions(List(olist), batch_size)
53- for j in range (BB ):
49+ @jit (nopython = True )
50+ def get_predictions (olist , batch_size ):
51+ bboxlists = []
52+ variances = [0.1 , 0.2 ]
53+ for j in range (batch_size ):
5454 bboxlist = []
5555 for i in range (len (olist ) // 2 ):
5656 ocls , oreg = olist [i * 2 ], olist [i * 2 + 1 ]
57- FB , FC , FH , FW = ocls .size () # feature map size
5857 stride = 2 ** (i + 2 ) # 4,8,16,32,64,128
59- anchor = stride * 4
6058 poss = zip (* np .where (ocls [:, 1 , :, :] > 0.05 ))
6159 for Iindex , hindex , windex in poss :
6260 axc , ayc = stride / 2 + windex * stride , stride / 2 + hindex * stride
6361 score = ocls [j , 1 , hindex , windex ]
64- loc = oreg [j , :, hindex , windex ].contiguous ().view (1 , 4 )
65- priors = torch .Tensor ([[axc / 1.0 , ayc / 1.0 , stride * 4 / 1.0 , stride * 4 / 1.0 ]])
66- variances = [0.1 , 0.2 ]
62+ loc = oreg [j , :, hindex , windex ].copy ().reshape (1 , 4 )
63+ priors = np .array ([[axc / 1.0 , ayc / 1.0 , stride * 4 / 1.0 , stride * 4 / 1.0 ]])
6764 box = decode (loc , priors , variances )
68- x1 , y1 , x2 , y2 = box [0 ] * 1.0
65+ x1 , y1 , x2 , y2 = box [0 ]
6966 bboxlist .append ([x1 , y1 , x2 , y2 , score ])
7067
7168 bboxlists .append (bboxlist )
0 commit comments