11"""Main code for Multi-Template-Matching (MTM)."""
2+ import os
3+ import warnings
4+ from concurrent .futures import ThreadPoolExecutor , as_completed
5+
26import cv2
3- import numpy as np
7+ import numpy as np
48import pandas as pd
5- import warnings
9+ from scipy . signal import find_peaks
610from skimage .feature import peak_local_max
7- from scipy .signal import find_peaks
8- from .version import __version__
911
1012from .NMS import NMS
13+ from .version import __version__
1114
1215__all__ = ['NMS' ]
1316
@@ -33,7 +36,7 @@ def _findLocalMax_(corrMap, score_threshold=0.6):
3336 peaks = [[i ,0 ] for i in peaks [0 ]]
3437
3538
36- else : # Correlatin map is 2D
39+ else : # Correlation map is 2D
3740 peaks = peak_local_max (corrMap , threshold_abs = score_threshold , exclude_border = False ).tolist ()
3841
3942 return peaks
@@ -116,82 +119,110 @@ def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=floa
116119 -------
117120 - Pandas DataFrame with 1 row per hit and column "TemplateName"(string), "BBox":(X, Y, Width, Height), "Score":float
118121 """
119- if N_object != float ("inf" ) and type (N_object ) != int :
122+ if N_object != float ("inf" ) and not isinstance (N_object , int ) :
120123 raise TypeError ("N_object must be an integer" )
121124
122125 ## Crop image to search region if provided
123126 if searchBox is not None :
124127 xOffset , yOffset , searchWidth , searchHeight = searchBox
125128 image = image [yOffset : yOffset + searchHeight , xOffset : xOffset + searchWidth ]
126-
127129 else :
128130 xOffset = yOffset = 0
129-
131+
130132 # Check that the template are all smaller are equal to the image (original, or cropped if there is a search region)
131133 for index , tempTuple in enumerate (listTemplates ):
132-
134+
133135 if not isinstance (tempTuple , tuple ) or len (tempTuple )== 1 :
134136 raise ValueError ("listTemplates should be a list of tuples as ('name','array') or ('name', 'array', 'mask')" )
135-
137+
136138 templateSmallerThanImage = all (templateDim <= imageDim for templateDim , imageDim in zip (tempTuple [1 ].shape , image .shape ))
137-
139+
138140 if not templateSmallerThanImage :
139141 fitIn = "searchBox" if (searchBox is not None ) else "image"
140142 raise ValueError ("Template '{}' at index {} in the list of templates is larger than {}." .format (tempTuple [0 ], index , fitIn ) )
141-
143+
142144 listHit = []
143- for tempTuple in listTemplates :
145+ ## Use multi-threading to iterate through all templates, using half the number of cpu cores available.
146+ with ThreadPoolExecutor (max_workers = round (os .cpu_count ()* .5 )) as executor :
147+ futures = [executor .submit (_multi_compute , tempTuple , image , method , N_object , score_threshold , xOffset , yOffset , listHit ) for tempTuple in listTemplates ]
148+ for future in as_completed (futures ):
149+ _ = future .result ()
144150
145- templateName , template = tempTuple [:2 ]
146- mask = None
151+ if listHit :
152+ return pd .DataFrame (listHit ) # All possible hits before Non-Maxima Supression
153+ else :
154+ return pd .DataFrame (columns = ["TemplateName" , "BBox" , "Score" ])
147155
148- if len (tempTuple )>= 3 : # ie a mask is also provided
149- if method in (0 ,3 ):
150- mask = tempTuple [2 ]
151- else :
152- warnings .warn ("Template matching method not supporting the use of Mask. Use 0/TM_SQDIFF or 3/TM_CCORR_NORMED." )
153156
154- #print('\nSearch with template : ',templateName)
155- corrMap = computeScoreMap (template , image , method , mask = mask )
157+ def _multi_compute (tempTuple , image , method , N_object , score_threshold , xOffset , yOffset , listHit ):
158+ """
159+ Find all possible template locations satisfying the score threshold provided a template to search and an image.
160+ Add the hits in the list of hits.
161+
162+ Parameters
163+ ----------
164+ - tempTuple : a tuple (LabelString, template, mask (optional))
165+ template to search in each image, associated to a label
166+ labelstring : string
167+ template : numpy array (grayscale or RGB)
168+ mask (optional): numpy array, should have the same dimensions and type than the template
156169
157- ## Find possible location of the object
158- if N_object == 1 : # Detect global Min/Max
159- minVal , maxVal , minLoc , maxLoc = cv2 .minMaxLoc (corrMap )
170+ - image : Grayscale or RGB numpy array
171+ image in which to perform the search, it should be the same bitDepth and number of channels than the templates
160172
161- if method in ( 0 , 1 ):
162- peaks = [ minLoc [:: - 1 ]] # opposite sorting than in the multiple detection
173+ - method : int
174+ one of OpenCV template matching method (0 to 5), default 5=0-mean cross-correlation
163175
164- else :
165- peaks = [ maxLoc [:: - 1 ]]
176+ - N_object: int or float("inf")
177+ expected number of objects in the image, default to infinity if unknown
166178
179+ - score_threshold: float in range [0,1]
180+ if N_object>1, returns local minima/maxima respectively below/above the score_threshold
167181
168- else :# Detect local max or min
169- if method in (0 ,1 ): # Difference => look for local minima
170- peaks = _findLocalMin_ (corrMap , score_threshold )
182+ - xOffset : int
183+ optional the x offset if the search area is provided
171184
172- else :
173- peaks = _findLocalMax_ ( corrMap , score_threshold )
185+ - yOffset : int
186+ optional the y offset if the search area is provided
174187
188+ - listHit : the list of hits which we want to add the discovered hit
189+ expected array of hits
190+ """
191+ templateName , template = tempTuple [:2 ]
192+ mask = None
175193
176- #print('Initially found',len(peaks),'hit with this template')
194+ if len (tempTuple )>= 3 : # ie a mask is also provided
195+ if method in (0 ,3 ):
196+ mask = tempTuple [2 ]
197+ else :
198+ warnings .warn ("Template matching method not supporting the use of Mask. Use 0/TM_SQDIFF or 3/TM_CCORR_NORMED." )
177199
200+ #print('\nSearch with template : ',templateName)
201+ corrMap = computeScoreMap (template , image , method , mask = mask )
178202
179- # Once every peak was detected for this given template
180- ## Create a dictionnary for each hit with {'TemplateName':, 'BBox': (x,y,Width, Height), 'Score':coeff}
203+ ## Find possible location of the object
204+ if N_object == 1 : # Detect global Min/Max
205+ _ , _ , minLoc , maxLoc = cv2 .minMaxLoc (corrMap )
206+ if method in (0 ,1 ):
207+ peaks = [minLoc [::- 1 ]] # opposite sorting than in the multiple detection
208+ else :
209+ peaks = [maxLoc [::- 1 ]]
210+ else :# Detect local max or min
211+ if method in (0 ,1 ): # Difference => look for local minima
212+ peaks = _findLocalMin_ (corrMap , score_threshold )
213+ else :
214+ peaks = _findLocalMax_ (corrMap , score_threshold )
181215
182- height , width = template . shape [ 0 : 2 ] # slicing make sure it works for RGB too
216+ #print('Initially found',len(peaks),'hit with this template')
183217
184- for peak in peaks :
185- coeff = corrMap [tuple (peak )]
186- newHit = {'TemplateName' :templateName , 'BBox' : ( int (peak [1 ])+ xOffset , int (peak [0 ])+ yOffset , width , height ) , 'Score' :coeff }
218+ # Once every peak was detected for this given template
219+ ## Create a dictionnary for each hit with {'TemplateName':, 'BBox': (x,y,Width, Height), 'Score':coeff}
187220
188- # append to list of potential hit before Non maxima suppression
189- listHit .append (newHit )
221+ height , width = template .shape [0 :2 ] # slicing make sure it works for RGB too
190222
191- if listHit :
192- return pd .DataFrame (listHit ) # All possible hits before Non-Maxima Supression
193- else :
194- return pd .DataFrame (columns = ["TemplateName" , "BBox" , "Score" ]) # empty df with correct column header
223+ for peak in peaks :
224+ # append to list of potential hit before Non maxima suppression
225+ listHit .append ({'TemplateName' :templateName , 'BBox' : ( int (peak [1 ])+ xOffset , int (peak [0 ])+ yOffset , width , height ) , 'Score' :corrMap [tuple (peak )]}) # empty df with correct column header
195226
196227
197228def matchTemplates (listTemplates , image , method = cv2 .TM_CCOEFF_NORMED , N_object = float ("inf" ), score_threshold = 0.5 , maxOverlap = 0.25 , searchBox = None ):
@@ -239,7 +270,7 @@ def matchTemplates(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=f
239270 tableHit = findMatches (listTemplates , image , method , N_object , score_threshold , searchBox )
240271
241272 if method == 0 : raise ValueError ("The method TM_SQDIFF is not supported. Use TM_SQDIFF_NORMED instead." )
242- sortAscending = True if method == 1 else False
273+ sortAscending = ( method == 1 )
243274
244275 return NMS (tableHit , score_threshold , sortAscending , N_object , maxOverlap )
245276
@@ -275,7 +306,7 @@ def drawBoxesOnRGB(image, tableHit, boxThickness=2, boxColor=(255, 255, 00), sho
275306 if image .ndim == 2 : outImage = cv2 .cvtColor (image , cv2 .COLOR_GRAY2RGB ) # convert to RGB to be able to show detections as color box on grayscale image
276307 else : outImage = image .copy ()
277308
278- for index , row in tableHit .iterrows ():
309+ for _ , row in tableHit .iterrows ():
279310 x ,y ,w ,h = row ['BBox' ]
280311 cv2 .rectangle (outImage , (x , y ), (x + w , y + h ), color = boxColor , thickness = boxThickness )
281312 if showLabel : cv2 .putText (outImage , text = row ['TemplateName' ], org = (x , y ), fontFace = cv2 .FONT_HERSHEY_SIMPLEX , fontScale = labelScale , color = labelColor , lineType = cv2 .LINE_AA )
@@ -315,9 +346,9 @@ def drawBoxesOnGray(image, tableHit, boxThickness=2, boxColor=255, showLabel=Fal
315346 if image .ndim == 3 : outImage = cv2 .cvtColor (image , cv2 .COLOR_RGB2GRAY ) # convert to RGB to be able to show detections as color box on grayscale image
316347 else : outImage = image .copy ()
317348
318- for index , row in tableHit .iterrows ():
349+ for _ , row in tableHit .iterrows ():
319350 x ,y ,w ,h = row ['BBox' ]
320351 cv2 .rectangle (outImage , (x , y ), (x + w , y + h ), color = boxColor , thickness = boxThickness )
321352 if showLabel : cv2 .putText (outImage , text = row ['TemplateName' ], org = (x , y ), fontFace = cv2 .FONT_HERSHEY_SIMPLEX , fontScale = labelScale , color = labelColor , lineType = cv2 .LINE_AA )
322353
323- return outImage
354+ return outImage
0 commit comments