Skip to content

Commit a836892

Browse files
committed
Merge pull request opencv#10903 from alalek:ml_ann_test
2 parents 25518a1 + 12d2bd4 commit a836892

File tree

1 file changed

+51
-26
lines changed

1 file changed

+51
-26
lines changed

modules/ml/test/test_mltests2.cpp

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -252,31 +252,35 @@ TEST(ML_ANN, ActivationFunction)
252252
}
253253
}
254254

255-
TEST(ML_ANN, Method)
255+
CV_ENUM(ANN_MLP_METHOD, ANN_MLP::RPROP, ANN_MLP::ANNEAL)
256+
257+
typedef tuple<ANN_MLP_METHOD, string, int> ML_ANN_METHOD_Params;
258+
typedef TestWithParam<ML_ANN_METHOD_Params> ML_ANN_METHOD;
259+
260+
TEST_P(ML_ANN_METHOD, Test)
256261
{
262+
int methodType = get<0>(GetParam());
263+
string methodName = get<1>(GetParam());
264+
int N = get<2>(GetParam());
265+
257266
String folder = string(cvtest::TS::ptr()->get_data_path());
258267
String original_path = folder + "waveform.data";
259-
String dataname = folder + "waveform";
268+
String dataname = folder + "waveform" + '_' + methodName;
260269

261270
Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
262-
Mat responses(tdata2->getResponses().rows, 3, CV_32FC1, Scalar(0));
263-
for (int i = 0; i<tdata2->getResponses().rows; i++)
271+
Mat samples = tdata2->getSamples()(Range(0, N), Range::all());
272+
Mat responses(N, 3, CV_32FC1, Scalar(0));
273+
for (int i = 0; i < N; i++)
264274
responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1;
265-
Ptr<TrainData> tdata = TrainData::create(tdata2->getSamples(), ml::ROW_SAMPLE, responses);
275+
Ptr<TrainData> tdata = TrainData::create(samples, ml::ROW_SAMPLE, responses);
266276

267277
ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
268278
RNG& rng = theRNG();
269279
rng.state = 0;
270280
tdata->setTrainTestSplitRatio(0.8);
271281

272-
vector<int> methodType;
273-
methodType.push_back(ml::ANN_MLP::RPROP);
274-
methodType.push_back(ml::ANN_MLP::ANNEAL);
275-
// methodType.push_back(ml::ANN_MLP::BACKPROP); -----> NO BACKPROP TEST
276-
vector<String> methodName;
277-
methodName.push_back("_rprop");
278-
methodName.push_back("_anneal");
279-
// methodName.push_back("_backprop"); -----> NO BACKPROP TEST
282+
Mat testSamples = tdata->getTestSamples();
283+
280284
#ifdef GENERATE_TESTDATA
281285
{
282286
Ptr<ml::ANN_MLP> xx = ml::ANN_MLP_ANNEAL::create();
@@ -296,14 +300,13 @@ TEST(ML_ANN, Method)
296300
fs.release();
297301
}
298302
#endif
299-
for (size_t i = 0; i < methodType.size(); i++)
300303
{
301304
FileStorage fs;
302-
fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ + FileStorage::BASE64);
305+
fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ);
303306
Ptr<ml::ANN_MLP> x = ml::ANN_MLP_ANNEAL::create();
304307
x->read(fs.root());
305-
x->setTrainMethod(methodType[i]);
306-
if (methodType[i] == ml::ANN_MLP::ANNEAL)
308+
x->setTrainMethod(methodType);
309+
if (methodType == ml::ANN_MLP::ANNEAL)
307310
{
308311
x->setAnnealEnergyRNG(RNG(CV_BIG_INT(0xffffffff)));
309312
x->setAnnealInitialT(12);
@@ -313,28 +316,50 @@ TEST(ML_ANN, Method)
313316
}
314317
x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 100, 0.01));
315318
x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
316-
ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << methodName[i];
319+
ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << methodName;
320+
string filename = dataname + ".yml.gz";
321+
Mat r_gold;
317322
#ifdef GENERATE_TESTDATA
318-
x->save(dataname + methodName[i] + ".yml.gz");
323+
x->save(filename);
324+
x->predict(testSamples, r_gold);
325+
{
326+
FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::WRITE + FileStorage::BASE64);
327+
fs_response << "response" << r_gold;
328+
}
329+
#else
330+
{
331+
FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::READ);
332+
fs_response["response"] >> r_gold;
333+
}
319334
#endif
320-
Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(dataname + methodName[i] + ".yml.gz");
321-
ASSERT_TRUE(y != NULL) << "Could not load " << dataname + methodName[i] + ".yml";
322-
Mat testSamples = tdata->getTestSamples();
323-
Mat rx, ry, dst;
335+
ASSERT_FALSE(r_gold.empty());
336+
Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(filename);
337+
ASSERT_TRUE(y != NULL) << "Could not load " << filename;
338+
Mat rx, ry;
324339
for (int j = 0; j < 4; j++)
325340
{
326341
rx = x->getWeights(j);
327342
ry = y->getWeights(j);
328343
double n = cvtest::norm(rx, ry, NORM_INF);
329-
EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i] << " layer : " << j;
344+
EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for layer: " << j;
330345
}
331346
x->predict(testSamples, rx);
332347
y->predict(testSamples, ry);
333-
double n = cvtest::norm(rx, ry, NORM_INF);
334-
EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i];
348+
double n = cvtest::norm(ry, rx, NORM_INF);
349+
EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to result of the saved model";
350+
n = cvtest::norm(r_gold, rx, NORM_INF);
351+
EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to 'gold' response";
335352
}
336353
}
337354

355+
INSTANTIATE_TEST_CASE_P(/*none*/, ML_ANN_METHOD,
356+
testing::Values(
357+
make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::RPROP, "rprop", 5000),
358+
make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::ANNEAL, "anneal", 1000)
359+
//make_pair<ANN_MLP_METHOD, string>(ml::ANN_MLP::BACKPROP, "backprop", 5000); -----> NO BACKPROP TEST
360+
)
361+
);
362+
338363

339364
// 6. dtree
340365
// 7. boost

0 commit comments

Comments
 (0)