Skip to content

Commit c0dcec9

Browse files
fantes and mergify[bot]
authored and committed
feat(torch): support for multiple test sets
1 parent 1d85f00 commit c0dcec9

28 files changed

+1381
-391
lines changed

src/backends/caffe/caffeinputconns.cc

Lines changed: 79 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -611,6 +611,13 @@ namespace dd
611611
write_images_to_hdf5(img_lists.at(1), test_dbfullname, test_list,
612612
alphabet, max_ocr_length, false);
613613
}
614+
if (img_lists.size() > 2)
615+
{
616+
_logger->error(
617+
"multiple test sets not supported by caffe backend yet");
618+
throw InputConnectorBadParamException(
619+
"multiple test sets not supported by caffe backend yet");
620+
}
614621

615622
// save the alphabet as corresp file
616623
std::ofstream correspf(_model_repo + "/" + _correspname, std::ios::binary);
@@ -1341,8 +1348,9 @@ namespace dd
13411348
}
13421349

13431350
/*- DDCCsv -*/
1344-
int DDCCsv::read_file(const std::string &fname)
1351+
int DDCCsv::read_file(const std::string &fname, int test_id)
13451352
{
1353+
(void)test_id;
13461354
if (_cifc)
13471355
{
13481356
_cifc->read_csv(fname);
@@ -1483,7 +1491,8 @@ namespace dd
14831491
{
14841492
if (!_db)
14851493
{
1486-
CSVInputFileConn::add_test_csvline(id, vals);
1494+
// MULTIPLE TEST SETS : we consider here only 1 test set
1495+
CSVInputFileConn::add_test_csvline(0, id, vals);
14871496
return;
14881497
}
14891498

@@ -1535,8 +1544,9 @@ namespace dd
15351544
if (!fileops::file_exists(_csv_fname))
15361545
throw InputConnectorBadParamException("training CSV file " + _csv_fname
15371546
+ " does not exist");
1547+
// MULTIPLE TEST SETS : we consider here only 1 test set
15381548
if (_uris.size() > 1)
1539-
_csv_test_fname = _uris.at(1);
1549+
_csv_test_fnames.push_back(_uris.at(1));
15401550
/*if (ad_input.has("label"))
15411551
_label = ad_input.get("label").get<std::string>();
15421552
else if (_train && _label.empty()) throw
@@ -1644,13 +1654,27 @@ namespace dd
16441654
// XXX: remove in-memory data, which pre-processing is useless and
16451655
// should be avoided
16461656
destroy_txt_entries(_txt);
1647-
destroy_txt_entries(_test_txt);
1657+
// MULTIPLE TEST SETS : we consider here only 1 test set
1658+
// destroy_txt_entries(_test_txt);
1659+
for (auto tt : _tests_txt)
1660+
destroy_txt_entries(tt);
1661+
_tests_txt.clear();
16481662

16491663
return 0;
16501664
}
16511665

16521666
_db_batchsize = _txt.size();
1653-
_db_testbatchsize = _test_txt.size();
1667+
// MULTIPLE TEST SETS : we consider here only 1 test set
1668+
1669+
if (_tests_txt.size() > 1)
1670+
{
1671+
_logger->error(
1672+
"multiple test sets not supported by caffe backend yet");
1673+
throw InputConnectorBadParamException(
1674+
"multiple test sets not supported by caffe backend yet");
1675+
}
1676+
1677+
_db_testbatchsize = _tests_txt[0].size();
16541678

16551679
_logger->info("db_batchsize={} / db_testbatchsize={}", _db_batchsize,
16561680
_db_testbatchsize);
@@ -1661,13 +1685,19 @@ namespace dd
16611685
else
16621686
write_sparse_txt_to_db(dbfullname, _txt);
16631687
destroy_txt_entries(_txt);
1664-
if (!_test_txt.empty())
1688+
// MULTIPLE TEST SETS : we consider here only 1 test set
1689+
if (!_tests_txt.empty() && !_tests_txt[0].empty())
16651690
{
16661691
if (!_sparse)
1667-
write_txt_to_db(testdbfullname, _test_txt);
1692+
// MULTIPLE TEST SETS : we consider here only 1 test set
1693+
write_txt_to_db(testdbfullname, _tests_txt[0]);
16681694
else
1669-
write_sparse_txt_to_db(testdbfullname, _test_txt);
1670-
destroy_txt_entries(_test_txt);
1695+
// MULTIPLE TEST SETS : we consider here only 1 test set
1696+
write_sparse_txt_to_db(testdbfullname, _tests_txt[0]);
1697+
// MULTIPLE TEST SETS : we consider here only 1 test set
1698+
for (auto tt : _tests_txt)
1699+
destroy_txt_entries(tt);
1700+
_tests_txt.clear();
16711701
}
16721702

16731703
return 0;
@@ -1678,11 +1708,17 @@ namespace dd
16781708
if (_cifc)
16791709
{
16801710
_cifc->_columns.clear();
1681-
std::string test_file = _cifc->_csv_test_fname;
1682-
_cifc->_csv_test_fname = "";
1711+
// MULTIPLE TEST SETS : we consider here only 1 test set
1712+
// std::string test_file = _cifc->_csv_test_fname;
1713+
std::vector<std::string> test_files = _cifc->_csv_test_fnames;
1714+
// MULTIPLE TEST SETS : we consider here only 1 test set
1715+
// _cifc->_csv_test_fname = "";
1716+
_cifc->_csv_test_fnames.clear();
16831717
_cifc->read_csv(fname);
16841718
_cifc->push_csv_to_csvts(is_test_data);
1685-
_cifc->_csv_test_fname = test_file;
1719+
// MULTIPLE TEST SETS : we consider here only 1 test set
1720+
//_cifc->_csv_test_fname = test_file;
1721+
_cifc->_csv_test_fnames = test_files;
16861722
_cifc->_ids.push_back(fname);
16871723
return 0;
16881724
}
@@ -1728,8 +1764,10 @@ namespace dd
17281764

17291765
//- read all test files
17301766
std::unordered_set<std::string> testfiles;
1731-
if (!_cifc->_csv_test_fname.empty())
1732-
fileops::list_directory(_cifc->_csv_test_fname, true, false, true,
1767+
// MULTIPLE TEST SETS : we consider here only 1 test set
1768+
if (!_cifc->_csv_test_fnames.empty())
1769+
// MULTIPLE TEST SETS : we consider here only 1 test set
1770+
fileops::list_directory(_cifc->_csv_test_fnames[0], true, false, true,
17331771
testfiles);
17341772

17351773
std::unordered_set<std::string> allfiles = trainfiles;
@@ -1809,15 +1847,19 @@ namespace dd
18091847
if (_datadim != -1)
18101848
return;
18111849
if (is_test_data)
1812-
_datadim = _csvtsdata_test[0][0]._v.size() + 1;
1850+
// MULTIPLE TEST SETS : we consider here only 1 test set
1851+
_datadim = _csvtsdata_tests[0][0][0]._v.size() + 1;
18131852
else
18141853
_datadim = _csvtsdata[0][0]._v.size() + 1;
18151854
}
18161855

18171856
void CSVTSCaffeInputFileConn::push_csv_to_csvts(bool is_test_data)
18181857
{
1819-
1820-
CSVTSInputFileConn::push_csv_to_csvts(is_test_data);
1858+
// MULTIPLE TEST SETS : we consider here only one test set
1859+
if (is_test_data)
1860+
CSVTSInputFileConn::push_csv_to_csvts(0);
1861+
else
1862+
CSVTSInputFileConn::push_csv_to_csvts(-1);
18211863
set_datadim(is_test_data);
18221864
dv_to_db(is_test_data);
18231865
}
@@ -1875,14 +1917,26 @@ namespace dd
18751917
_db = true;
18761918
return; // done
18771919
}
1878-
_csvtsdata_test = std::move(_csvtsdata);
1920+
// MULTIPLE TEST SETS : we consider here only 1 test set
1921+
//_csvtsdata_test = std::move(_csvtsdata);
1922+
_csvtsdata_tests.push_back(std::move(_csvtsdata));
18791923
}
18801924
else
18811925
_csvtsdata.clear();
1926+
1927+
if (_csvtsdata_tests.size() > 1)
1928+
{
1929+
_logger->error(
1930+
"multiple test sets not supported by caffe backend yet");
1931+
throw InputConnectorBadParamException(
1932+
"multiple test sets not supported by caffe backend yet");
1933+
}
1934+
18821935
csvts_to_dv(true, true, true, false, _continuation);
1883-
_csvtsdata_test.clear();
1936+
// MULTIPLE TEST SETS : we consider here only 1 test set
1937+
_csvtsdata_tests[0].clear();
18841938
}
1885-
_csvtsdata_test.clear();
1939+
_csvtsdata_tests.clear();
18861940
}
18871941

18881942
void CSVTSCaffeInputFileConn::reset_dv_test()
@@ -2075,8 +2129,10 @@ namespace dd
20752129
if (!fileops::dir_exists(_csv_fname))
20762130
throw InputConnectorBadParamException("training CSV_TS dir " + _csv_fname
20772131
+ " does not exist");
2132+
// MULTIPLE TEST SETS : we consider here only 1 test set
20782133
if (_uris.size() > 1)
2079-
_csv_test_fname = _uris.at(1);
2134+
//_csv_test_fname = _uris.at(1);
2135+
_csv_test_fnames.push_back(_uris.at(1));
20802136
DDCCsvTS ddccsvts;
20812137
ddccsvts._cifc = this;
20822138
ddccsvts._adconf = ad_input;
@@ -2102,7 +2158,8 @@ namespace dd
21022158
{
21032159
dv = &_dv_test;
21042160
index = &_dv_test_index;
2105-
data = &this->_csvtsdata_test;
2161+
// MULTIPLE TEST SETS : we consider here only 1 test set
2162+
data = &this->_csvtsdata_tests[0];
21062163
}
21072164
else
21082165
{

src/backends/caffe/caffeinputconns.h

Lines changed: 73 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,13 @@ namespace dd
358358
"segmentation input test file " + _uris.at(1)
359359
+ " not found");
360360
}
361+
if (_uris.size() > 2)
362+
{
363+
_logger->error(
364+
"multiple test sets not supported by caffe backend yet");
365+
throw InputConnectorBadParamException(
366+
"multiple test sets not supported by caffe backend yet");
367+
}
361368

362369
// class weights if any
363370
write_class_weights(_model_repo, ad_mllib);
@@ -390,6 +397,13 @@ namespace dd
390397
"object detection input test file " + _uris.at(1)
391398
+ " not found");
392399
}
400+
if (_uris.size() > 2)
401+
{
402+
_logger->error(
403+
"multiple test sets not supported by caffe backend yet");
404+
throw InputConnectorBadParamException(
405+
"multiple test sets not supported by caffe backend yet");
406+
}
393407

394408
// - create lmdbs
395409
_dbfullname = _model_repo + "/" + _dbfullname;
@@ -451,6 +465,13 @@ namespace dd
451465
throw ex;*/
452466
return;
453467
}
468+
if (_uris.size() > 2)
469+
{
470+
_logger->error(
471+
"multiple test sets not supported by caffe backend yet");
472+
throw InputConnectorBadParamException(
473+
"multiple test sets not supported by caffe backend yet");
474+
}
454475
if (!this->_db)
455476
{
456477
// create test db for image data layer (no db of images)
@@ -611,11 +632,12 @@ namespace dd
611632
{
612633
}
613634

614-
int read_file(const std::string &fname);
635+
int read_file(const std::string &fname, int test_id);
615636
int read_db(const std::string &fname);
616637
int read_mem(const std::string &content);
617-
int read_dir(const std::string &dir)
638+
int read_dir(const std::string &dir, int test_id)
618639
{
640+
(void)test_id;
619641
throw InputConnectorBadParamException(
620642
"uri " + dir + " is a directory, requires a CSV file");
621643
}
@@ -754,12 +776,23 @@ namespace dd
754776
_db = true;
755777
return; // done
756778
}
757-
_csvdata_test = std::move(_csvdata);
779+
// MULTIPLE TEST SETS : we consider here only 1 test set
780+
_csvdata_tests.push_back(std::move(_csvdata));
758781
}
759782
else
760783
_csvdata.clear();
761-
auto hit = _csvdata_test.begin();
762-
while (hit != _csvdata_test.end())
784+
if (_csvdata_tests.size() > 1)
785+
{
786+
_logger->error(
787+
"multiple test sets not supported by caffe backend yet");
788+
throw InputConnectorBadParamException(
789+
"multiple test sets not supported by caffe backend yet");
790+
}
791+
792+
// MULTIPLE TEST SETS : we consider here only 1 test set
793+
auto hit = _csvdata_tests[0].begin();
794+
// MULTIPLE TEST SETS : we consider here only 1 test set
795+
while (hit != _csvdata_tests[0].end())
763796
{
764797
// no ids taken on the test set
765798
if (_label.size() == 1)
@@ -779,9 +812,11 @@ namespace dd
779812
this->_ids.push_back((*hit)._str);
780813
++hit;
781814
}
782-
_csvdata_test.clear();
815+
// MULTIPLE TEST SETS : we consider here only 1 test set
816+
_csvdata_tests[0].clear();
783817
}
784-
_csvdata_test.clear();
818+
// MULTIPLE TEST SETS : we consider here only 1 test set
819+
_csvdata_tests[0].clear();
785820
}
786821

787822
std::vector<caffe::Datum> get_dv_test(const int &num,
@@ -1192,36 +1227,46 @@ namespace dd
11921227
_db = true;
11931228
return; // done
11941229
}
1195-
_test_txt = std::move(_txt);
1230+
// MULTIPLE TEST SETS : we consider here only 1 test set
1231+
// _test_txt = std::move(_txt);
1232+
_tests_txt.resize(1);
1233+
_tests_txt.push_back(std::move(_txt));
11961234
}
11971235

11981236
int n = 0;
1199-
auto hit = _test_txt.begin();
1200-
while (hit != _test_txt.end())
1237+
// MULTIPLE TEST SETS: here we push all tests in dv_test w/o taking
1238+
// into account rank into request. this is okay for prediction, but
1239+
// will do weird stuff when multiple test sets will be implemented on
1240+
// caffe side
1241+
for (size_t i = 0; i < _tests_txt.size(); ++i)
12011242
{
1202-
if (!_sparse)
1203-
{
1204-
if (_characters)
1205-
_dv_test.push_back(std::move(to_datum<TxtCharEntry>(
1206-
static_cast<TxtCharEntry *>((*hit)))));
1207-
else
1208-
_dv_test.push_back(std::move(to_datum<TxtBowEntry>(
1209-
static_cast<TxtBowEntry *>((*hit)))));
1210-
}
1211-
else
1243+
auto hit = _tests_txt[i].begin();
1244+
while (hit != _tests_txt[i].end())
12121245
{
1213-
if (_characters)
1246+
if (!_sparse)
12141247
{
1215-
// TODO
1248+
if (_characters)
1249+
_dv_test.push_back(std::move(to_datum<TxtCharEntry>(
1250+
static_cast<TxtCharEntry *>((*hit)))));
1251+
else
1252+
_dv_test.push_back(std::move(to_datum<TxtBowEntry>(
1253+
static_cast<TxtBowEntry *>((*hit)))));
12161254
}
12171255
else
1218-
_dv_test_sparse.push_back(std::move(
1219-
to_sparse_datum(static_cast<TxtBowEntry *>((*hit)))));
1256+
{
1257+
if (_characters)
1258+
{
1259+
// TODO
1260+
}
1261+
else
1262+
_dv_test_sparse.push_back(std::move(to_sparse_datum(
1263+
static_cast<TxtBowEntry *>((*hit)))));
1264+
}
1265+
if (!_train)
1266+
this->_ids.push_back(std::to_string(n));
1267+
++hit;
1268+
++n;
12201269
}
1221-
if (!_train)
1222-
this->_ids.push_back(std::to_string(n));
1223-
++hit;
1224-
++n;
12251270
}
12261271
}
12271272
}

src/backends/caffe/caffelib.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1856,7 +1856,8 @@ namespace dd
18561856

18571857
if (m != "cmdiag" && m != "cmfull" && m != "clacc"
18581858
&& m != "labels" && m != "cliou" && m != "precisions"
1859-
&& m != "recalls" && m != "f1s")
1859+
&& m != "recalls" && m != "f1s" && m != "test_id"
1860+
&& m != "test_name")
18601861
// do not report confusion matrix in server logs
18611862
{
18621863
double mval = meas_obj.get(m).get<double>();

0 commit comments

Comments
 (0)