Skip to content

Commit 57cbde3

Browse files
committed
Merge pull request opencv#10798 from mshabunin:split-stat
2 parents 579781e + 4437e0c commit 57cbde3

File tree

9 files changed

+4517
-4504
lines changed

9 files changed

+4517
-4504
lines changed
Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
// This file is part of OpenCV project.
2+
// It is subject to the license terms in the LICENSE file found in the top-level directory
3+
// of this distribution and at http://opencv.org/license.html
4+
5+
6+
#include "precomp.hpp"
7+
#include "stat.hpp"
8+
9+
namespace cv
10+
{
11+
12+
template<typename _Tp, typename _Rt>
13+
void batchDistL1_(const _Tp* src1, const _Tp* src2, size_t step2,
14+
int nvecs, int len, _Rt* dist, const uchar* mask)
15+
{
16+
step2 /= sizeof(src2[0]);
17+
if( !mask )
18+
{
19+
for( int i = 0; i < nvecs; i++ )
20+
dist[i] = normL1<_Tp, _Rt>(src1, src2 + step2*i, len);
21+
}
22+
else
23+
{
24+
_Rt val0 = std::numeric_limits<_Rt>::max();
25+
for( int i = 0; i < nvecs; i++ )
26+
dist[i] = mask[i] ? normL1<_Tp, _Rt>(src1, src2 + step2*i, len) : val0;
27+
}
28+
}
29+
30+
template<typename _Tp, typename _Rt>
31+
void batchDistL2Sqr_(const _Tp* src1, const _Tp* src2, size_t step2,
32+
int nvecs, int len, _Rt* dist, const uchar* mask)
33+
{
34+
step2 /= sizeof(src2[0]);
35+
if( !mask )
36+
{
37+
for( int i = 0; i < nvecs; i++ )
38+
dist[i] = normL2Sqr<_Tp, _Rt>(src1, src2 + step2*i, len);
39+
}
40+
else
41+
{
42+
_Rt val0 = std::numeric_limits<_Rt>::max();
43+
for( int i = 0; i < nvecs; i++ )
44+
dist[i] = mask[i] ? normL2Sqr<_Tp, _Rt>(src1, src2 + step2*i, len) : val0;
45+
}
46+
}
47+
48+
template<typename _Tp, typename _Rt>
49+
void batchDistL2_(const _Tp* src1, const _Tp* src2, size_t step2,
50+
int nvecs, int len, _Rt* dist, const uchar* mask)
51+
{
52+
step2 /= sizeof(src2[0]);
53+
if( !mask )
54+
{
55+
for( int i = 0; i < nvecs; i++ )
56+
dist[i] = std::sqrt(normL2Sqr<_Tp, _Rt>(src1, src2 + step2*i, len));
57+
}
58+
else
59+
{
60+
_Rt val0 = std::numeric_limits<_Rt>::max();
61+
for( int i = 0; i < nvecs; i++ )
62+
dist[i] = mask[i] ? std::sqrt(normL2Sqr<_Tp, _Rt>(src1, src2 + step2*i, len)) : val0;
63+
}
64+
}
65+
66+
static void batchDistHamming(const uchar* src1, const uchar* src2, size_t step2,
67+
int nvecs, int len, int* dist, const uchar* mask)
68+
{
69+
step2 /= sizeof(src2[0]);
70+
if( !mask )
71+
{
72+
for( int i = 0; i < nvecs; i++ )
73+
dist[i] = hal::normHamming(src1, src2 + step2*i, len);
74+
}
75+
else
76+
{
77+
int val0 = INT_MAX;
78+
for( int i = 0; i < nvecs; i++ )
79+
{
80+
if (mask[i])
81+
dist[i] = hal::normHamming(src1, src2 + step2*i, len);
82+
else
83+
dist[i] = val0;
84+
}
85+
}
86+
}
87+
88+
static void batchDistHamming2(const uchar* src1, const uchar* src2, size_t step2,
89+
int nvecs, int len, int* dist, const uchar* mask)
90+
{
91+
step2 /= sizeof(src2[0]);
92+
if( !mask )
93+
{
94+
for( int i = 0; i < nvecs; i++ )
95+
dist[i] = hal::normHamming(src1, src2 + step2*i, len, 2);
96+
}
97+
else
98+
{
99+
int val0 = INT_MAX;
100+
for( int i = 0; i < nvecs; i++ )
101+
{
102+
if (mask[i])
103+
dist[i] = hal::normHamming(src1, src2 + step2*i, len, 2);
104+
else
105+
dist[i] = val0;
106+
}
107+
}
108+
}
109+
110+
static void batchDistL1_8u32s(const uchar* src1, const uchar* src2, size_t step2,
111+
int nvecs, int len, int* dist, const uchar* mask)
112+
{
113+
batchDistL1_<uchar, int>(src1, src2, step2, nvecs, len, dist, mask);
114+
}
115+
116+
static void batchDistL1_8u32f(const uchar* src1, const uchar* src2, size_t step2,
117+
int nvecs, int len, float* dist, const uchar* mask)
118+
{
119+
batchDistL1_<uchar, float>(src1, src2, step2, nvecs, len, dist, mask);
120+
}
121+
122+
static void batchDistL2Sqr_8u32s(const uchar* src1, const uchar* src2, size_t step2,
123+
int nvecs, int len, int* dist, const uchar* mask)
124+
{
125+
batchDistL2Sqr_<uchar, int>(src1, src2, step2, nvecs, len, dist, mask);
126+
}
127+
128+
static void batchDistL2Sqr_8u32f(const uchar* src1, const uchar* src2, size_t step2,
129+
int nvecs, int len, float* dist, const uchar* mask)
130+
{
131+
batchDistL2Sqr_<uchar, float>(src1, src2, step2, nvecs, len, dist, mask);
132+
}
133+
134+
static void batchDistL2_8u32f(const uchar* src1, const uchar* src2, size_t step2,
135+
int nvecs, int len, float* dist, const uchar* mask)
136+
{
137+
batchDistL2_<uchar, float>(src1, src2, step2, nvecs, len, dist, mask);
138+
}
139+
140+
static void batchDistL1_32f(const float* src1, const float* src2, size_t step2,
141+
int nvecs, int len, float* dist, const uchar* mask)
142+
{
143+
batchDistL1_<float, float>(src1, src2, step2, nvecs, len, dist, mask);
144+
}
145+
146+
static void batchDistL2Sqr_32f(const float* src1, const float* src2, size_t step2,
147+
int nvecs, int len, float* dist, const uchar* mask)
148+
{
149+
batchDistL2Sqr_<float, float>(src1, src2, step2, nvecs, len, dist, mask);
150+
}
151+
152+
static void batchDistL2_32f(const float* src1, const float* src2, size_t step2,
153+
int nvecs, int len, float* dist, const uchar* mask)
154+
{
155+
batchDistL2_<float, float>(src1, src2, step2, nvecs, len, dist, mask);
156+
}
157+
158+
typedef void (*BatchDistFunc)(const uchar* src1, const uchar* src2, size_t step2,
159+
int nvecs, int len, uchar* dist, const uchar* mask);
160+
161+
162+
struct BatchDistInvoker : public ParallelLoopBody
163+
{
164+
BatchDistInvoker( const Mat& _src1, const Mat& _src2,
165+
Mat& _dist, Mat& _nidx, int _K,
166+
const Mat& _mask, int _update,
167+
BatchDistFunc _func)
168+
{
169+
src1 = &_src1;
170+
src2 = &_src2;
171+
dist = &_dist;
172+
nidx = &_nidx;
173+
K = _K;
174+
mask = &_mask;
175+
update = _update;
176+
func = _func;
177+
}
178+
179+
void operator()(const Range& range) const
180+
{
181+
AutoBuffer<int> buf(src2->rows);
182+
int* bufptr = buf;
183+
184+
for( int i = range.start; i < range.end; i++ )
185+
{
186+
func(src1->ptr(i), src2->ptr(), src2->step, src2->rows, src2->cols,
187+
K > 0 ? (uchar*)bufptr : dist->ptr(i), mask->data ? mask->ptr(i) : 0);
188+
189+
if( K > 0 )
190+
{
191+
int* nidxptr = nidx->ptr<int>(i);
192+
// since positive float's can be compared just like int's,
193+
// we handle both CV_32S and CV_32F cases with a single branch
194+
int* distptr = (int*)dist->ptr(i);
195+
196+
int j, k;
197+
198+
for( j = 0; j < src2->rows; j++ )
199+
{
200+
int d = bufptr[j];
201+
if( d < distptr[K-1] )
202+
{
203+
for( k = K-2; k >= 0 && distptr[k] > d; k-- )
204+
{
205+
nidxptr[k+1] = nidxptr[k];
206+
distptr[k+1] = distptr[k];
207+
}
208+
nidxptr[k+1] = j + update;
209+
distptr[k+1] = d;
210+
}
211+
}
212+
}
213+
}
214+
}
215+
216+
const Mat *src1;
217+
const Mat *src2;
218+
Mat *dist;
219+
Mat *nidx;
220+
const Mat *mask;
221+
int K;
222+
int update;
223+
BatchDistFunc func;
224+
};
225+
226+
}
227+
228+
// Computes distances between every row of _src1 and every row of _src2.
//   _src1, _src2 - CV_8U or CV_32F matrices of the same type and column count
//   _dist        - output distances: src1.rows x K when K > 0, otherwise
//                  src1.rows x src2.rows
//   dtype        - output depth; -1 selects CV_32S for the Hamming norms and
//                  CV_32F otherwise. CV_32S is only accepted for CV_8U input.
//   _nidx        - output neighbour indices (CV_32S); must be requested
//                  exactly when K > 0
//   normType     - NORM_L1, NORM_L2, NORM_L2SQR, NORM_HAMMING or NORM_HAMMING2
//                  (subject to the type/dtype combinations dispatched below)
//   K            - number of nearest neighbours to keep; clamped to src2.rows
//   _mask        - optional mask, read as one byte per (src1 row, src2 row)
//                  pair; its size and type are not validated here
//   update       - offset added to stored indices; when 0 (and K > 0) the
//                  outputs are reset to "no neighbour" before the search
//   crosscheck   - keep only mutually nearest pairs (requires K == 1,
//                  update == 0 and no mask)
// Raises CV_StsUnsupportedFormat for unsupported type/dtype/normType
// combinations.
void cv::batchDistance( InputArray _src1, InputArray _src2,
                        OutputArray _dist, int dtype, OutputArray _nidx,
                        int normType, int K, InputArray _mask,
                        int update, bool crosscheck )
{
    CV_INSTRUMENT_REGION()

    Mat src1 = _src1.getMat(), src2 = _src2.getMat(), mask = _mask.getMat();
    int type = src1.type();
    CV_Assert( type == src2.type() && src1.cols == src2.cols &&
               (type == CV_32F || type == CV_8U));
    CV_Assert( _nidx.needed() == (K > 0) );

    if( dtype == -1 )
    {
        dtype = normType == NORM_HAMMING || normType == NORM_HAMMING2 ? CV_32S : CV_32F;
    }
    CV_Assert( (type == CV_8U && dtype == CV_32S) || dtype == CV_32F);

    K = std::min(K, src2.rows);

    _dist.create(src1.rows, (K > 0 ? K : src2.rows), dtype);
    Mat dist = _dist.getMat(), nidx;
    if( _nidx.needed() )
    {
        _nidx.create(dist.size(), CV_32S);
        nidx = _nidx.getMat();
    }

    if( update == 0 && K > 0 )
    {
        // sentinel values: "no neighbour found yet"
        dist = Scalar::all(dtype == CV_32S ? (double)INT_MAX : (double)FLT_MAX);
        nidx = Scalar::all(-1);
    }

    if( crosscheck )
    {
        CV_Assert( K == 1 && update == 0 && mask.empty() );
        Mat tdist, tidx;
        batchDistance(src2, src1, tdist, dtype, tidx, normType, K, mask, 0, false);

        // The reverse search yields zero columns when src1 has no rows; there is
        // nothing to cross-check then, and indexing tdist/tidx would be out of bounds.
        if( tdist.empty() )
            return;

        // if an idx-th element from src1 appeared to be the nearest to i-th element of src2,
        // we update the minimum mutual distance between idx-th element of src1 and the whole src2 set.
        // As a result, if nidx[idx] = i*, it means that idx-th element of src1 is the nearest
        // to i*-th element of src2 and i*-th element of src2 is the closest to idx-th element of src1.
        // If nidx[idx] = -1, it means that there is no such ideal couple for it in src2.
        // This O(N) procedure is called cross-check and it helps to eliminate some false matches.
        if( dtype == CV_32S )
        {
            for( int i = 0; i < tdist.rows; i++ )
            {
                int idx = tidx.at<int>(i);
                // -1 means the reverse search kept no neighbour for this row
                // (its distance was never below the initial sentinel); using
                // it as an index would access memory outside dist/nidx.
                if( idx < 0 )
                    continue;
                int d = tdist.at<int>(i), d0 = dist.at<int>(idx);
                if( d < d0 )
                {
                    dist.at<int>(idx) = d;
                    nidx.at<int>(idx) = i + update;
                }
            }
        }
        else
        {
            for( int i = 0; i < tdist.rows; i++ )
            {
                int idx = tidx.at<int>(i);
                // -1 means the reverse search kept no neighbour for this row
                // (e.g. a distance that overflowed to +inf is not below the
                // FLT_MAX sentinel); skip it instead of indexing with -1.
                if( idx < 0 )
                    continue;
                float d = tdist.at<float>(i), d0 = dist.at<float>(idx);
                if( d < d0 )
                {
                    dist.at<float>(idx) = d;
                    nidx.at<int>(idx) = i + update;
                }
            }
        }
        return;
    }

    // Select the kernel for this (input type, output type, norm) combination.
    // The kernels take typed dist pointers; they are stored and invoked
    // through the common byte-pointer signature BatchDistFunc.
    BatchDistFunc func = 0;
    if( type == CV_8U )
    {
        if( normType == NORM_L1 && dtype == CV_32S )
            func = (BatchDistFunc)batchDistL1_8u32s;
        else if( normType == NORM_L1 && dtype == CV_32F )
            func = (BatchDistFunc)batchDistL1_8u32f;
        else if( normType == NORM_L2SQR && dtype == CV_32S )
            func = (BatchDistFunc)batchDistL2Sqr_8u32s;
        else if( normType == NORM_L2SQR && dtype == CV_32F )
            func = (BatchDistFunc)batchDistL2Sqr_8u32f;
        else if( normType == NORM_L2 && dtype == CV_32F )
            func = (BatchDistFunc)batchDistL2_8u32f;
        else if( normType == NORM_HAMMING && dtype == CV_32S )
            func = (BatchDistFunc)batchDistHamming;
        else if( normType == NORM_HAMMING2 && dtype == CV_32S )
            func = (BatchDistFunc)batchDistHamming2;
    }
    else if( type == CV_32F && dtype == CV_32F )
    {
        if( normType == NORM_L1 )
            func = (BatchDistFunc)batchDistL1_32f;
        else if( normType == NORM_L2SQR )
            func = (BatchDistFunc)batchDistL2Sqr_32f;
        else if( normType == NORM_L2 )
            func = (BatchDistFunc)batchDistL2_32f;
    }

    if( func == 0 )
        CV_Error_(CV_StsUnsupportedFormat,
                  ("The combination of type=%d, dtype=%d and normType=%d is not supported",
                   type, dtype, normType));

    // one parallel task range over the rows of src1
    parallel_for_(Range(0, src1.rows),
                  BatchDistInvoker(src1, src2, dist, nidx, K, mask, update, func));
}

0 commit comments

Comments
 (0)