@@ -273,31 +273,37 @@ void MKLDNNTester::printVector(const VectorPtr& v) {
273
273
VLOG (MKLDNN_ALL) << std::endl << ostr.str ();
274
274
}
275
275
276
- double MKLDNNTester::getDelta (const real* d1 ,
277
- const real* d2 ,
276
+ double MKLDNNTester::getDelta (const real* refer ,
277
+ const real* value ,
278
278
size_t len,
279
279
const float failRate,
280
280
const float thres) {
281
281
double delta = 0 , sum = 0 ;
282
282
int failCnt = 0 ;
283
283
const double eps = 1e-5 ;
284
- double maxOut = 0 ;
284
+ double maxRatio = 0 ;
285
285
for (size_t i = 0 ; i < len; ++i) {
286
- double ref = fabs (d2[i]);
287
- double diff = fabs (d1[i] - d2[i]);
286
+ double ref = fabs (refer[i]);
287
+ double val = fabs (value[i]);
288
+ double diff = fabs (refer[i] - value[i]);
288
289
delta += diff;
289
290
sum += ref;
290
- if (ref > eps && fabs (d1[i]) > eps && diff / ref > thres) {
291
- maxOut = std::max (maxOut, diff / ref);
291
+ if (ref < eps && val < eps) { // both values are very small
292
+ continue ;
293
+ }
294
+ double ratio = diff / ref;
295
+ if (ratio > thres) {
296
+ maxRatio = std::max (maxRatio, ratio);
292
297
failCnt++;
293
298
}
294
299
}
295
- EXPECT_TRUE (std::isnormal (sum));
296
300
EXPECT_FALSE (std::isinf (sum));
301
+ EXPECT_FALSE (std::isnan (sum));
297
302
EXPECT_FALSE (std::isnan (delta));
298
303
VLOG (MKLDNN_ALL) << " reference avg data: " << sum / len
299
304
<< " , delta: " << delta / sum << " , failCnt:" << failCnt;
300
- return (failCnt / (float )len) > failRate ? maxOut : delta / sum;
305
+ double res = sum > eps ? delta / sum : eps;
306
+ return (failCnt / (float )len) > failRate ? maxRatio : res;
301
307
}
302
308
303
309
double MKLDNNTester::compareMatrix (const MatrixPtr& m1, const MatrixPtr& m2) {
@@ -543,12 +549,12 @@ void MKLDNNTester::getOutResult(const std::string& configPath,
543
549
void MKLDNNTester::compareResult (DataOut& ref, DataOut& dnn, float eps) {
544
550
CHECK_EQ (ref.outValues .size (), dnn.outValues .size ());
545
551
CHECK_EQ (ref.paraValues .size (), dnn.paraValues .size ());
546
- VLOG (MKLDNN_TESTS) << " compare value size: " << ref.outValues .size ();
547
552
for (size_t i = 0 ; i < ref.outValues .size (); i++) {
553
+ VLOG (MKLDNN_TESTS) << " compare value index: " << i;
548
554
EXPECT_LE (fabs (compareMatrix (ref.outValues [i], dnn.outValues [i])), eps);
549
555
}
550
- VLOG (MKLDNN_TESTS) << " compare param size: " << ref.outValues .size ();
551
556
for (size_t i = 0 ; i < ref.paraValues .size (); i++) {
557
+ VLOG (MKLDNN_TESTS) << " compare param index: " << i;
552
558
EXPECT_LE (fabs (compareVector (ref.paraValues [i], dnn.paraValues [i])), eps);
553
559
}
554
560
}
0 commit comments