Skip to content

Commit 00ad211

Browse files
committed
Merge pull request #1127 from arrybn:accuracy_scripts
2 parents b97931e + 053303a commit 00ad211

14 files changed

+860
-68
lines changed

modules/dnn/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ endif()
99

1010
set(the_description "Deep neural network module. It allows to load models from different frameworks and to make forward pass")
1111

12-
ocv_add_module(dnn opencv_core opencv_imgproc)
12+
ocv_add_module(dnn opencv_core opencv_imgproc WRAP python matlab)
1313
ocv_warnings_disable(CMAKE_CXX_FLAGS -Wno-shadow -Wno-parentheses -Wmaybe-uninitialized -Wsign-promo
1414
-Wmissing-declarations -Wmissing-prototypes
1515
)

modules/dnn/include/opencv2/dnn/dnn.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,16 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
304304
*/
305305
CV_EXPORTS_W Net readNetFromCaffe(const String &prototxt, const String &caffeModel = String());
306306

307+
/** @brief Reads a network model stored in Tensorflow model file.
308+
* @details This is shortcut consisting from createTensorflowImporter and Net::populateNet calls.
309+
*/
310+
CV_EXPORTS_W Net readNetFromTensorflow(const String &model);
311+
312+
/** @brief Reads a network model stored in Torch model file.
313+
* @details This is shortcut consisting from createTorchImporter and Net::populateNet calls.
314+
*/
315+
CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true);
316+
307317
/** @brief Creates the importer of <a href="http://www.tensorflow.org">TensorFlow</a> framework network.
308318
* @param model path to the .pb file with binary protobuf description of the network architecture.
309319
* @returns Pointer to the created importer, NULL in failure cases.

modules/dnn/misc/python/pyopencv_dnn.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,10 @@ bool pyopencv_to(PyObject *o, dnn::DictValue &dv, const char *name)
2626
return false;
2727
}
2828

29+
template<>
30+
bool pyopencv_to(PyObject *o, std::vector<Mat> &blobs, const char *name) //required for Layer::blobs RW
31+
{
32+
return pyopencvVecConverter<Mat>::to(o, blobs, ArgInfo(name, false));
33+
}
34+
2935
#endif

modules/dnn/samples/enet-classes.txt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
Unlabeled 0 0 0
2+
Road 128 64 128
3+
Sidewalk 244 35 232
4+
Building 70 70 70
5+
Wall 102 102 156
6+
Fence 190 153 153
7+
Pole 153 153 153
8+
TrafficLight 250 170 30
9+
TrafficSign 220 220 0
10+
Vegetation 107 142 35
11+
Terrain 152 251 152
12+
Sky 70 130 180
13+
Person 220 20 60
14+
Rider 255 0 0
15+
Car 0 0 142
16+
Truck 0 0 70
17+
Bus 0 60 100
18+
Train 0 80 100
19+
Motorcycle 0 0 230
20+
Bicycle 119 11 32

modules/dnn/samples/tf_inception.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ int main(int argc, char **argv)
9999

100100
Mat inputBlob = blobFromImage(img); //Convert Mat to image batch
101101
//! [Prepare blob]
102-
102+
inputBlob -= 117.0;
103103
//! [Set input blob]
104104
net.setBlob(inBlobName, inputBlob); //set the network input
105105
//! [Set input blob]

modules/dnn/samples/torch_enet.cpp

Lines changed: 49 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ const String keys =
2626
"{o_blob || output blob's name. If empty, last blob's name in net is used}"
2727
;
2828

29-
std::vector<String> readClassNames(const char *filename);
3029
static void colorizeSegmentation(const Mat &score, Mat &segm,
31-
Mat &legend, vector<String> &classNames);
30+
Mat &legend, vector<String> &classNames, vector<Vec3b> &colors);
31+
static vector<Vec3b> readColors(const String &filename, vector<String>& classNames);
3232

3333
int main(int argc, char **argv)
3434
{
@@ -52,43 +52,21 @@ int main(int argc, char **argv)
5252
String classNamesFile = parser.get<String>("c_names");
5353
String resultFile = parser.get<String>("result");
5454

55-
//! [Create the importer of TensorFlow model]
56-
Ptr<dnn::Importer> importer;
57-
try //Try to import TensorFlow AlexNet model
58-
{
59-
importer = dnn::createTorchImporter(modelFile);
60-
}
61-
catch (const cv::Exception &err) //Importer can throw errors, we will catch them
62-
{
63-
std::cerr << err.msg << std::endl;
64-
}
65-
//! [Create the importer of Caffe model]
66-
67-
if (!importer)
68-
{
69-
std::cerr << "Can't load network by using the mode file: " << std::endl;
70-
std::cerr << modelFile << std::endl;
71-
exit(-1);
72-
}
73-
74-
//! [Initialize network]
75-
dnn::Net net;
76-
importer->populateNet(net);
77-
importer.release(); //We don't need importer anymore
78-
//! [Initialize network]
55+
//! [Read model and initialize network]
56+
dnn::Net net = dnn::readNetFromTorch(modelFile);
7957

8058
//! [Prepare blob]
81-
Mat img = imread(imageFile, 1);
82-
59+
Mat img = imread(imageFile), input;
8360
if (img.empty())
8461
{
8562
std::cerr << "Can't read image from the file: " << imageFile << std::endl;
8663
exit(-1);
8764
}
8865

89-
Size inputImgSize(512, 512);
66+
Size origSize = img.size();
67+
Size inputImgSize = cv::Size(1024, 512);
9068

91-
if (inputImgSize != img.size())
69+
if (inputImgSize != origSize)
9270
resize(img, img, inputImgSize); //Resize image to input size
9371

9472
Mat inputBlob = blobFromImage(img, 1./255, true); //Convert Mat to image batch
@@ -135,20 +113,18 @@ int main(int argc, char **argv)
135113

136114
if (parser.has("show"))
137115
{
138-
size_t nclasses = result.size[1];
139116
std::vector<String> classNames;
117+
vector<cv::Vec3b> colors;
140118
if(!classNamesFile.empty()) {
141-
classNames = readClassNames(classNamesFile.c_str());
142-
if (classNames.size() > nclasses)
143-
classNames = std::vector<String>(classNames.begin() + classNames.size() - nclasses,
144-
classNames.end());
119+
colors = readColors(classNamesFile, classNames);
145120
}
146121
Mat segm, legend;
147-
colorizeSegmentation(result, segm, legend, classNames);
122+
colorizeSegmentation(result, segm, legend, classNames, colors);
148123

149124
Mat show;
150-
addWeighted(img, 0.2, segm, 0.8, 0.0, show);
125+
addWeighted(img, 0.1, segm, 0.9, 0.0, show);
151126

127+
cv::resize(show, show, origSize, 0, 0, cv::INTER_NEAREST);
152128
imshow("Result", show);
153129
if(classNames.size())
154130
imshow("Legend", legend);
@@ -158,44 +134,16 @@ int main(int argc, char **argv)
158134
return 0;
159135
} //main
160136

161-
162-
std::vector<String> readClassNames(const char *filename)
163-
{
164-
std::vector<String> classNames;
165-
166-
std::ifstream fp(filename);
167-
if (!fp.is_open())
168-
{
169-
std::cerr << "File with classes labels not found: " << filename << std::endl;
170-
exit(-1);
171-
}
172-
173-
std::string name;
174-
while (!fp.eof())
175-
{
176-
std::getline(fp, name);
177-
if (name.length())
178-
classNames.push_back(name);
179-
}
180-
181-
fp.close();
182-
return classNames;
183-
}
184-
185-
static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vector<String> &classNames)
137+
static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vector<String> &classNames, vector<Vec3b> &colors)
186138
{
187139
const int rows = score.size[2];
188140
const int cols = score.size[3];
189141
const int chns = score.size[1];
190142

191-
vector<Vec3i> colors;
192-
RNG rng(12345678);
193-
194143
cv::Mat maxCl(rows, cols, CV_8UC1);
195144
cv::Mat maxVal(rows, cols, CV_32FC1);
196145
for (int ch = 0; ch < chns; ch++)
197146
{
198-
colors.push_back(Vec3i(rng.uniform(0, 256), rng.uniform(0, 256), rng.uniform(0, 256)));
199147
for (int row = 0; row < rows; row++)
200148
{
201149
const float *ptrScore = score.ptr<float>(0, ch, row);
@@ -235,3 +183,38 @@ static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vecto
235183
}
236184
}
237185
}
186+
187+
static vector<Vec3b> readColors(const String &filename, vector<String>& classNames)
188+
{
189+
vector<cv::Vec3b> colors;
190+
classNames.clear();
191+
192+
ifstream fp(filename.c_str());
193+
if (!fp.is_open())
194+
{
195+
cerr << "File with colors not found: " << filename << endl;
196+
exit(-1);
197+
}
198+
199+
string line;
200+
while (!fp.eof())
201+
{
202+
getline(fp, line);
203+
if (line.length())
204+
{
205+
stringstream ss(line);
206+
207+
string name; ss >> name;
208+
int temp;
209+
cv::Vec3b color;
210+
ss >> temp; color[0] = temp;
211+
ss >> temp; color[1] = temp;
212+
ss >> temp; color[2] = temp;
213+
classNames.push_back(name);
214+
colors.push_back(color);
215+
}
216+
}
217+
218+
fp.close();
219+
return colors;
220+
}

modules/dnn/src/dnn.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,10 @@ void Net::setBlob(String outputName, const Mat &blob_)
604604

605605
LayerData &ld = impl->layers[pin.lid];
606606
ld.outputBlobs.resize( std::max(pin.oid+1, (int)ld.requiredOutputs.size()) );
607+
MatSize prevShape = ld.outputBlobs[pin.oid].size;
607608
ld.outputBlobs[pin.oid] = blob_.clone();
609+
610+
impl->netWasAllocated = prevShape == blob_.size;
608611
}
609612

610613
Mat Net::getBlob(String outputName)

modules/dnn/src/tensorflow/tf_importer.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,23 @@ void TFImporter::populateNet(Net dstNet)
736736

737737
} // namespace
738738

739+
Net cv::dnn::readNetFromTensorflow(const String &model)
740+
{
741+
Ptr<Importer> importer;
742+
try
743+
{
744+
importer = createTensorflowImporter(model);
745+
}
746+
catch(...)
747+
{
748+
}
749+
750+
Net net;
751+
if (importer)
752+
importer->populateNet(net);
753+
return net;
754+
}
755+
739756
Ptr<Importer> cv::dnn::createTensorflowImporter(const String &model)
740757
{
741758
return Ptr<Importer>(new TFImporter(model.c_str()));

modules/dnn/src/torch/torch_importer.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,24 @@ Mat readTorchBlob(const String &filename, bool isBinary)
970970

971971
return importer->tensors.begin()->second;
972972
}
973+
974+
Net readNetFromTorch(const String &model, bool isBinary)
975+
{
976+
Ptr<Importer> importer;
977+
try
978+
{
979+
importer = createTorchImporter(model, isBinary);
980+
}
981+
catch(...)
982+
{
983+
}
984+
985+
Net net;
986+
if (importer)
987+
importer->populateNet(net);
988+
return net;
989+
}
990+
973991
#else
974992

975993
Ptr<Importer> createTorchImporter(const String&, bool)

0 commit comments

Comments
 (0)