
Commit 97d49a8

add C++ tutorial samples about data generation and classifier
1 parent 197fba6 commit 97d49a8

8 files changed: +313 -31 lines

modules/cnn_3dobj/samples/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
@@ -3,15 +3,15 @@ SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g -ggdb ")
 SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")
 project(sphereview_test)
 find_package(OpenCV REQUIRED)
-set(SOURCES_generator sphereview_3dobj_demo.cpp)
+set(SOURCES_generator demo_sphereview_data.cpp)
 include_directories(${OpenCV_INCLUDE_DIRS})
 add_executable(sphereview_test ${SOURCES_generator})
 target_link_libraries(sphereview_test ${OpenCV_LIBS})

-set(SOURCES_classifier classifyIMG_demo.cpp)
+set(SOURCES_classifier demo_classify.cpp)
 add_executable(classify_test ${SOURCES_classifier})
 target_link_libraries(classify_test ${OpenCV_LIBS})

-set(SOURCES_modelanalysis model_analysis_demo.cpp)
+set(SOURCES_modelanalysis demo_model_analysis.cpp)
 add_executable(model_test ${SOURCES_modelanalysis})
 target_link_libraries(model_test ${OpenCV_LIBS})

modules/cnn_3dobj/samples/classifyIMG_demo.cpp renamed to modules/cnn_3dobj/samples/demo_classify.cpp

Lines changed: 33 additions & 7 deletions
@@ -32,6 +32,11 @@
  * POSSIBILITY OF SUCH DAMAGE.
  *
  */
+/**
+ * @file demo_classify.cpp
+ * @brief Feature extraction and classification.
+ * @author Yida Wang
+ */
 #define HAVE_CAFFE
 #include <opencv2/cnn_3dobj.hpp>
 #include <opencv2/features2d/features2d.hpp>
@@ -40,7 +45,10 @@ using namespace cv;
 using namespace std;
 using namespace cv::cnn_3dobj;

-/* Get the file name from a root dictionary. */
+/**
+ * @function listDir
+ * @brief Make a list of all file names under a directory
+ */
 void listDir(const char *path, std::vector<string>& files, bool r)
 {
     DIR *pDir;
@@ -70,20 +78,25 @@ void listDir(const char *path, std::vector<string>& files, bool r)
     sort(files.begin(),files.end());
 };

+/**
+ * @function main
+ */
 int main(int argc, char** argv)
 {
-    const String keys = "{help | | this demo will convert a set of images in a particular path into leveldb database for feature extraction using Caffe. If there little variance in data such as human faces, you can add a mean_file, otherwise it is not so useful}"
+    const String keys = "{help | | This sample will extract features from reference images and a target image for classification. You can add a mean_file if there is little variance in data such as human faces; otherwise it is not so useful.}"
     "{src_dir | ../data/images_all/ | Source directory of the images used as the gallery for feature extraction.}"
     "{caffemodel | ../../testdata/cv/3d_triplet_iter_30000.caffemodel | Caffe model for feature extraction.}"
     "{network_forIMG | ../../testdata/cv/3d_triplet_testIMG.prototxt | Network definition file used for extracting a feature from a single image and making a classification.}"
     "{mean_file | no | The mean file generated by Caffe from all gallery images; this could be used for mean value subtraction from all images. If you want to use the mean file, you can set this as ../data/images_mean/triplet_mean.binaryproto.}"
     "{target_img | ../data/images_all/1_8.png | Path of the image waiting to be classified.}"
     "{feature_blob | feat | Name of the layer used as the feature; in this network, ip1 or feat works well.}"
     "{num_candidate | 15 | Number of candidates in the gallery as the prediction result.}"
-    "{device | CPU | device}"
-    "{dev_id | 0 | dev_id}";
+    "{device | CPU | Device type: CPU or GPU}"
+    "{dev_id | 0 | Device id}";
+
+    /* Get parameters from command line. */
     cv::CommandLineParser parser(argc, argv, keys);
-    parser.about("Demo for object data classification and pose estimation");
+    parser.about("Feature extraction and classification");
     if (parser.has("help"))
     {
         parser.printMessage();
@@ -99,13 +112,18 @@ int main(int argc, char** argv)
     string device = parser.get<string>("device");
     int dev_id = parser.get<int>("dev_id");

+    /* Initialize a network on the given device. */
     cv::cnn_3dobj::descriptorExtractor descriptor(device);
     std::cout << "Using" << descriptor.getDeviceType() << std::endl;
+
+    /* Load the net with the Caffe-trained network parameters and structure. */
     if (strcmp(mean_file.c_str(), "no") == 0)
         descriptor.loadNet(network_forIMG, caffemodel);
     else
         descriptor.loadNet(network_forIMG, caffemodel, mean_file);
     std::vector<string> name_gallery;
+
+    /* List the file names under a given path. */
     listDir(src_dir.c_str(), name_gallery, false);
     for (unsigned int i = 0; i < name_gallery.size(); i++)
     {
@@ -117,23 +135,31 @@ int main(int argc, char** argv)
     {
         img_gallery.push_back(cv::imread(name_gallery[i], -1));
     }
+
+    /* Extract features from the set of gallery images. */
     descriptor.extract(img_gallery, feature_reference, feature_blob);

     std::cout << std::endl << "---------- Prediction for " << target_img << " ----------" << std::endl;

     cv::Mat img = cv::imread(target_img, -1);
-    // CHECK(!img.empty()) << "Unable to decode image " << target_img;
     std::cout << std::endl << "---------- Features of gallery images ----------" << std::endl;
     std::vector<std::pair<string, float> > prediction;
+
+    /* Print features of the reference images. */
     for (unsigned int i = 0; i < feature_reference.rows; i++)
         std::cout << feature_reference.row(i) << endl;
     cv::Mat feature_test;
     descriptor.extract(img, feature_test, feature_blob);
+    /* Initialize a matcher using L2 distance. */
     cv::BFMatcher matcher(NORM_L2);
     std::vector<std::vector<cv::DMatch> > matches;
+    /* Run a KNN match between the target and reference images. */
     matcher.knnMatch(feature_test, feature_reference, matches, num_candidate);
+
+    /* Print the feature of the target image waiting to be classified. */
     std::cout << std::endl << "---------- Features of target image: " << target_img << "----------" << endl << feature_test << std::endl;
-    // Print the top N prediction.
+
+    /* Print the top N predictions. */
     std::cout << std::endl << "---------- Prediction result(Distance - File Name in Gallery) ----------" << std::endl;
     for (size_t i = 0; i < matches[0].size(); ++i)
     {
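The prediction step above is a plain nearest-neighbour search over CNN descriptors. Below is a minimal, self-contained sketch of just that matching step; the random matrices stand in for the extracted features, and all sizes and values are illustrative.

@code{.cpp}
// Rank gallery feature rows by L2 distance to a query feature with
// cv::BFMatcher::knnMatch, as the sample does after feature extraction.
#include <opencv2/core.hpp>
#include <opencv2/features2d.hpp>
#include <iostream>

int main()
{
    const int num_gallery = 10, dim = 16, num_candidate = 3;
    cv::Mat feature_reference(num_gallery, dim, CV_32F);
    cv::Mat feature_test(1, dim, CV_32F);
    cv::randu(feature_reference, 0.0, 1.0); // stand-in for gallery descriptors
    cv::randu(feature_test, 0.0, 1.0);      // stand-in for the target descriptor

    cv::BFMatcher matcher(cv::NORM_L2);
    std::vector<std::vector<cv::DMatch> > matches;
    matcher.knnMatch(feature_test, feature_reference, matches, num_candidate);

    // matches[0] holds the top-N gallery rows for the single query, nearest first.
    for (size_t i = 0; i < matches[0].size(); ++i)
        std::cout << "rank " << i << ": gallery row " << matches[0][i].trainIdx
                  << ", distance " << matches[0][i].distance << std::endl;
    return 0;
}
@endcode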

modules/cnn_3dobj/samples/model_analysis_demo.cpp renamed to modules/cnn_3dobj/samples/demo_model_analysis.cpp

Lines changed: 21 additions & 1 deletion
@@ -32,6 +32,11 @@
  * POSSIBILITY OF SUCH DAMAGE.
  *
  */
+/**
+ * @file demo_model_analysis.cpp
+ * @brief Analysis of the trained model based on feature distances.
+ * @author Yida Wang
+ */
 #define HAVE_CAFFE
 #include <iostream>
 #include "opencv2/imgproc.hpp"
@@ -52,6 +57,7 @@ int main(int argc, char** argv)
     "{feature_blob | feat | Name of the layer used as the feature; in this network, ip1 or feat works well.}"
     "{device | CPU | device}"
     "{dev_id | 0 | dev_id}";
+    /* Get parameters from command line. */
     cv::CommandLineParser parser(argc, argv, keys);
     parser.about("Demo for object data classification and pose estimation");
     if (parser.has("help"))
@@ -70,13 +76,23 @@ int main(int argc, char** argv)
     string device = parser.get<string>("device");
     int dev_id = parser.get<int>("dev_id");

-
     std::vector<string> ref_img;
+    /* Sample closest in pose to the reference image
+     * and of the same class.
+     */
     ref_img.push_back(ref_img1);
+    /* Sample less close in pose to the reference image
+     * but still of the same class.
+     */
     ref_img.push_back(ref_img2);
+    /* Sample very close in pose to the reference image
+     * but not of the same class.
+     */
     ref_img.push_back(ref_img3);

+    /* Initialize a network on the given device. */
     cv::cnn_3dobj::descriptorExtractor descriptor(device, dev_id);
+    /* Load the net with the Caffe-trained network parameters and structure. */
     if (strcmp(mean_file.c_str(), "no") == 0)
         descriptor.loadNet(network_forIMG, caffemodel);
     else
@@ -116,6 +132,10 @@ int main(int argc, char** argv)
     }
     bool pose_pass = false;
     bool class_pass = false;
+    /* Compare distances between the reference image and the 3 other images:
+     * the distance to the closest-pose sample should be smallest and the
+     * distance to the sample from another class should be largest.
+     */
     if (matches[0] < matches[1] && matches[0] < matches[2])
         pose_pass = true;
     if (matches[1] < matches[2])
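The two flags above encode the ordering a triplet-trained embedding should produce over the three probe images. A tiny self-contained sketch of the same check, with made-up distance values:

@code{.cpp}
// Ordering check: the closest-pose, same-class sample should be nearest to
// the reference, and the different-class sample should be farthest.
#include <iostream>

int main()
{
    // matches[0]: same class, closest pose; matches[1]: same class, farther
    // pose; matches[2]: different class (distances are illustrative).
    float matches[3] = {0.4f, 0.9f, 1.7f};
    bool pose_pass = (matches[0] < matches[1]) && (matches[0] < matches[2]);
    bool class_pass = matches[1] < matches[2];
    std::cout << "pose_pass=" << pose_pass
              << ", class_pass=" << class_pass << std::endl;
    return 0;
}
@endcode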

modules/cnn_3dobj/samples/sphereview_3dobj_demo.cpp renamed to modules/cnn_3dobj/samples/demo_sphereview_data.cpp

Lines changed: 29 additions & 20 deletions
@@ -32,6 +32,11 @@
  * POSSIBILITY OF SUCH DAMAGE.
  *
  */
+/**
+ * @file demo_sphereview_data.cpp
+ * @brief Generating training data for CNN with triplet loss.
+ * @author Yida Wang
+ */
 #define HAVE_CAFFE
 #include <opencv2/cnn_3dobj.hpp>
 #include <opencv2/viz/vizcore.hpp>
@@ -44,14 +49,15 @@ int main(int argc, char *argv[])
 {
     const String keys = "{help | | demo :$ ./sphereview_test -ite_depth=2 -plymodel=../data/3Dmodel/ape.ply -imagedir=../data/images_all/ -labeldir=../data/label_all.txt -num_class=4 -label_class=0, then press 'q' to start image generation when you see the gray background and coordinate axes.}"
     "{ite_depth | 2 | Iteration depth for sphere generation.}"
-    "{plymodel | ../data/3Dmodel/ape.ply | path of the '.ply' file for image rendering. }"
-    "{imagedir | ../data/images_all/ | path of the generated images for one particular .ply model. }"
-    "{labeldir | ../data/label_all.txt | path of the generated images for one particular .ply model. }"
-    "{num_class | 4 | total number of classes of models}"
-    "{label_class | 0 | class label of current .ply model}"
-    "{rgb_use | 0 | use RGB image or grayscale}";
+    "{plymodel | ../data/3Dmodel/ape.ply | Path of the '.ply' file for image rendering. }"
+    "{imagedir | ../data/images_all/ | Path of the generated images for one particular .ply model. }"
+    "{labeldir | ../data/label_all.txt | Path of the label file for one particular .ply model. }"
+    "{num_class | 4 | Total number of classes of models}"
+    "{label_class | 0 | Class label of the current .ply model}"
+    "{rgb_use | 0 | Use RGB images or grayscale}";
+    /* Get parameters from command line. */
     cv::CommandLineParser parser(argc, argv, keys);
-    parser.about("Demo for Sphere View data generation");
+    parser.about("Generating training data for CNN with triplet loss");
     if (parser.has("help"))
     {
         parser.printMessage();
@@ -70,23 +76,25 @@ int main(int argc, char *argv[])
     char* p=(char*)labeldir.data();
     imglabel.open(p, fstream::app|fstream::out);
     bool camera_pov = (true);
-    /// Create a window
+    /* Create a window using viz. */
     viz::Viz3d myWindow("Coordinate Frame");
+    /* Set the window size to 64*64; this scale is used as the default. */
     myWindow.setWindowSize(Size(64,64));
-    /// Add coordinate axes
+    /* Add coordinate axes. */
     myWindow.showWidget("Coordinate Widget", viz::WCoordinateSystem());
+    /* Set background color. */
     myWindow.setBackgroundColor(viz::Color::gray());
     myWindow.spin();
-    /// Set background color
-    /// Let's assume camera has the following properties
-    /// Create a cloud widget.
+    /* Create a Mesh widget, loading .ply models. */
     viz::Mesh objmesh = viz::Mesh::load(plymodel);
+    /* Get the center of the mesh, because some .ply models are not centered at the origin. */
     Point3d cam_focal_point = ViewSphere.getCenter(objmesh.cloud);
     float radius = ViewSphere.getRadius(objmesh.cloud, cam_focal_point);
     Point3d cam_y_dir(0.0f,0.0f,1.0f);
     const char* headerPath = "../data/header_for_";
     const char* binaryPath = "../data/binary_";
     ViewSphere.createHeader((int)campos.size(), 64, 64, headerPath);
+    /* Images will be saved as .png files. */
     for(int pose = 0; pose < (int)campos.size(); pose++){
         char* temp = new char;
         sprintf (temp,"%d",label_class);
@@ -97,17 +105,16 @@ int main(int argc, char *argv[])
     filename += ".png";
     imglabel << filename << ' ' << (int)(campos.at(pose).x*100) << ' ' << (int)(campos.at(pose).y*100) << ' ' << (int)(campos.at(pose).z*100) << endl;
     filename = imagedir + filename;
-    /// We can get the pose of the cam using makeCameraPoses
+    /* Get the pose of the camera using makeCameraPose. */
     Affine3f cam_pose = viz::makeCameraPose(campos.at(pose)*radius+cam_focal_point, cam_focal_point, cam_y_dir*radius+cam_focal_point);
-    /// We can get the transformation matrix from camera coordinate system to global using
-    /// - makeTransformToGlobal. We need the axes of the camera
+    /* Get the transformation matrix from the camera coordinate system to the global one. */
     Affine3f transform = viz::makeTransformToGlobal(Vec3f(1.0f,0.0f,0.0f), Vec3f(0.0f,1.0f,0.0f), Vec3f(0.0f,0.0f,1.0f), campos.at(pose));
     viz::WMesh mesh_widget(objmesh);
-    /// Pose of the widget in camera frame
+    /* Pose of the widget in the camera frame. */
     Affine3f cloud_pose = Affine3f().translate(Vec3f(1.0f,1.0f,1.0f));
-    /// Pose of the widget in global frame
+    /* Pose of the widget in the global frame. */
     Affine3f cloud_pose_global = transform * cloud_pose;
-    /// Visualize camera frame
+    /* Visualize the camera frame. */
     if (!camera_pov)
     {
         viz::WCameraPosition cpw(1); // Coordinate axes
@@ -116,14 +123,16 @@ int main(int argc, char *argv[])
         myWindow.showWidget("CPW_FRUSTUM", cpw_frustum, cam_pose);
     }

-    /// Visualize widget
+    /* Visualize the widget. */
     mesh_widget.setRenderingProperty(viz::LINE_WIDTH, 4.0);
     myWindow.showWidget("ape", mesh_widget, cloud_pose_global);

-    /// Set the viewer pose to that of camera
+    /* Set the viewer pose to that of the camera. */
     if (camera_pov)
         myWindow.setViewerPose(cam_pose);
+    /* Save a screenshot as an image. */
     myWindow.saveScreenshot(filename);
+    /* Write images into binary files for further use in CNN training. */
     ViewSphere.writeBinaryfile(filename, binaryPath, headerPath,(int)campos.size()*num_class, label_class, (int)(campos.at(pose).x*100), (int)(campos.at(pose).y*100), (int)(campos.at(pose).z*100), rgb_use);
 }
 imglabel.close();
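The core of the loop above is placing the viz camera on a sphere around the mesh center. A minimal sketch of that placement, assuming hard-coded stand-ins for the focal point, radius, and viewpoint that the demo derives from the loaded mesh and the icosphere:

@code{.cpp}
// Place the camera at one spherical viewpoint looking at a focal point,
// using viz::makeCameraPose as the sample does for every generated pose.
#include <opencv2/viz/vizcore.hpp>

int main()
{
    cv::viz::Viz3d window("Coordinate Frame");
    window.setWindowSize(cv::Size(64, 64));

    cv::Point3d cam_focal_point(0.0, 0.0, 0.0); // assumed mesh center
    double radius = 2.0;                        // assumed viewing distance
    cv::Point3d campos(0.0, 0.7071, 0.7071);    // one assumed unit viewpoint
    cv::Point3d cam_y_dir(0.0, 0.0, 1.0);       // camera up hint, as in the demo

    cv::Affine3f cam_pose = cv::viz::makeCameraPose(
        campos * radius + cam_focal_point,      // camera position on the sphere
        cam_focal_point,                        // look-at point
        cam_y_dir * radius + cam_focal_point);  // y direction of the camera

    window.setViewerPose(cam_pose);
    window.spinOnce(1, true);                   // render one frame
    return 0;
}
@endcode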
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
Training data generation using Icosphere {#tutorial_data_generation}
=============

Goal
----

In this tutorial you will learn how to

- Construct a point cloud of camera viewpoints on a sphere.
- Generate training images from a 3D model.

Code
----

You can download the code from [here](https://github.com/Wangyida/opencv_contrib/blob/cnn_3dobj/samples/demo_sphereview_data.cpp).
@include cnn_3dobj/samples/demo_sphereview_data.cpp

Explanation
-----------

Here is the general structure of the program:

- Create a window.
@code{.cpp}
viz::Viz3d myWindow("Coordinate Frame");
@endcode

- Set the window size to 64*64; this scale is used as the default.
@code{.cpp}
myWindow.setWindowSize(Size(64,64));
@endcode

- Add coordinate axes.
@code{.cpp}
myWindow.showWidget("Coordinate Widget", viz::WCoordinateSystem());
myWindow.setBackgroundColor(viz::Color::gray());
myWindow.spin();
@endcode

- Create a Mesh widget, loading .ply models.
@code{.cpp}
viz::Mesh objmesh = viz::Mesh::load(plymodel);
@endcode

- Get the center of the mesh, because some .ply models are not centered at the origin.
@code{.cpp}
Point3d cam_focal_point = ViewSphere.getCenter(objmesh.cloud);
@endcode

- Get the pose of the camera using makeCameraPose.
@code{.cpp}
Affine3f cam_pose = viz::makeCameraPose(campos.at(pose)*radius+cam_focal_point, cam_focal_point, cam_y_dir*radius+cam_focal_point);
@endcode

- Get the transformation matrix from the camera coordinate system to the global one.
@code{.cpp}
Affine3f transform = viz::makeTransformToGlobal(Vec3f(1.0f,0.0f,0.0f), Vec3f(0.0f,1.0f,0.0f), Vec3f(0.0f,0.0f,1.0f), campos.at(pose));
viz::WMesh mesh_widget(objmesh);
@endcode

- Save screenshots as images.
@code{.cpp}
myWindow.saveScreenshot(filename);
@endcode

- Write images into binary files for further use in CNN training.
@code{.cpp}
ViewSphere.writeBinaryfile(filename, binaryPath, headerPath,(int)campos.size()*num_class, label_class, (int)(campos.at(pose).x*100), (int)(campos.at(pose).y*100), (int)(campos.at(pose).z*100), rgb_use);
@endcode

Results
-------

Here is a collection of images created by this demo using 4 models.

![](images_all/1_8.png)
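For reference, the invocation quoted in the sample's own help text generates the images and label file; press 'q' once the gray background and coordinate axes appear to start the generation.

@code{.bash}
./sphereview_test -ite_depth=2 -plymodel=../data/3Dmodel/ape.ply -imagedir=../data/images_all/ -labeldir=../data/label_all.txt -num_class=4 -label_class=0
@endcode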
