Skip to content

Commit 66f5cf0

Browse files
authored
WIF - Introduce mlpack support (#2)
1 parent 1b2d8fe commit 66f5cf0

21 files changed

+2043
-2
lines changed

cmake/dependencies.cmake

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Project dependencies
2-
find_package(Boost REQUIRED COMPONENTS regex)
2+
find_package(Armadillo REQUIRED)
3+
find_package(Boost REQUIRED COMPONENTS regex serialization)
4+
find_package(OpenMP REQUIRED)
35
find_package(Python3 REQUIRED COMPONENTS Development NumPy)
46

57
if(BUILD_WITH_UNIREC)
@@ -9,3 +11,9 @@ endif()
911

1012
# Define NPY_NO_DEPRECATED_API so that only the non-deprecated NumPy API is used
1113
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION")
14+
15+
# Enable OpenMP for C++ sources.
# CMAKE_CXX_FLAGS is passed to the compiler driver for both compiling and
# linking, so appending the OpenMP flags here once is sufficient. The flags
# were previously also passed through add_compile_options(), which repeated
# them on every compile line and, because the variable was expanded unquoted,
# would hand a multi-word flag string to the compiler as a single argument.
if(OpenMP_CXX_FOUND)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
endif()
19+
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/**
2+
* @file
3+
* @author Jachym Hudlicky <[email protected]>
4+
* @brief Mlpack classifier interface
5+
*
6+
* SPDX-License-Identifier: BSD-3-Clause
7+
*/
8+
9+
#pragma once
10+
11+
#include "wif/classifiers/classifier.hpp"
12+
#include "wif/ml/mlpackWrapper.hpp"
13+
14+
#include <memory>
15+
#include <string>
16+
#include <vector>
17+
18+
namespace WIF {
19+
20+
/**
21+
* @brief Classifier that performs ML classification through the mlpack library
22+
*
* The mlpack model is held by an owned MlpackWrapper instance (see m_mlpackWrapper).
23+
*/
24+
class MlpackClassifier : public Classifier {
25+
public:
26+
/**
27+
* @brief Construct a new Mlpack Classifier object
28+
*
29+
* @param path path to the file with the trained model.
30+
* @param logicalName logical name of the trained model (defaults to "trained_data").
31+
*/
32+
MlpackClassifier(const std::string& path, const std::string& logicalName = "trained_data");
33+
34+
/**
35+
* @brief Set the feature IDs which will be used for classification
36+
*
37+
* @param sourceFeatureIDs IDs of the features to use for classification
38+
*/
39+
void setFeatureSourceIDs(const std::vector<FeatureID>& sourceFeatureIDs) override;
40+
41+
/**
42+
* @brief Classify a single FlowFeatures object
43+
* See std::vector<ClfResult> classify(const std::vector<FlowFeatures>&) for more details
44+
*
45+
* @param flowFeatures flow features to classify
46+
* @return ClfResult result of the classification, which contains a double representing the class or
47+
* a vector<double> with probabilities for each class (depends on model).
48+
*/
49+
ClfResult classify(const FlowFeatures& flowFeatures) override;
50+
51+
/**
52+
* @brief Classify a burst of flow features
53+
*
54+
* @param burstOfFlowFeatures the burst of flow features to classify
55+
* @return std::vector<ClfResult> classification results with one ClfResult object for each
56+
* flow features object
57+
*/
58+
std::vector<ClfResult> classify(const std::vector<FlowFeatures>& burstOfFlowFeatures) override;
59+
60+
/**
61+
* @brief Return the path of the ML model which is currently loaded
62+
* @return std::string path of the model (returned by value).
63+
*/
64+
const std::string getMlModelPath() const noexcept;
65+
66+
/**
67+
* @brief Reload the model from the file which was set in the constructor
68+
*
69+
* @param logicalName logical name of the trained model (defaults to "trained_data").
70+
*/
71+
void reloadModelFromDisk(const std::string& logicalName = "trained_data");
72+
73+
private:
74+
/**
75+
* @brief Owning pointer to the wrapper object with the loaded mlpack model
76+
*/
77+
std::unique_ptr<MlpackWrapper> m_mlpackWrapper;
78+
};
79+
80+
} // namespace WIF
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
/**
2+
* @file
3+
* @author Jachym Hudlicky <[email protected]>
4+
* @brief Mlpack AdaBoost model class
5+
*
6+
* SPDX-License-Identifier: BSD-3-Clause
7+
*/
8+
9+
#include "wif/ml/mlpackModels/mlpackModel.hpp"
10+
#include "wif/storage/clfResult.hpp"
11+
#include "wif/storage/flowFeatures.hpp"
12+
13+
#include <armadillo>
14+
#include <mlpack.hpp>
15+
#include <stdexcept>
16+
#include <string>
17+
#include <utility>
18+
#include <variant>
19+
#include <vector>
20+
21+
namespace WIF::MlpackModels {
22+
23+
/**
24+
* @brief Class which provides AdaBoost model weak learner from Mlpack library
25+
* @tparam WeakLearnerType weak learner used by AdaBoost
26+
*/
27+
template<typename WeakLearnerType>
28+
class AdaBoostModel : public MlpackModel {
29+
public:
30+
/**
31+
* @brief Construct a new AdaBoost wrapper object with no loaded model
32+
*/
33+
AdaBoostModel() = default;
34+
35+
/**
36+
* @brief Construct a new AdaBoost wrapper object
37+
*
38+
* @param modelPath contains path to trained model file
39+
* @param logicalName contains the logical name of the trained model.
40+
*/
41+
AdaBoostModel(const std::string& modelPath, const std::string& logicalName = "trained_data")
42+
{
43+
m_loaded = mlpack::data::Load(modelPath, logicalName, m_ab, true);
44+
if (m_loaded) {
45+
m_modelPath = modelPath;
46+
}
47+
}
48+
49+
/**
50+
* @brief Classify single flowFeature object
51+
*
52+
* @param flowFeatures flow features to classify
53+
* @return ClfResult result of the classification, which contains vector<double> with
54+
* probabilities for each class.
55+
*/
56+
ClfResult classify(const FlowFeatures& flowFeatures) override
57+
{
58+
arma::mat testDataset(m_featureIDs.size(), 1);
59+
arma::Row<size_t> predictions;
60+
arma::mat probaMatrix;
61+
std::vector<ClfResult> burstResults;
62+
std::vector<FlowFeatures> burstOfFeatures = {flowFeatures};
63+
64+
burstResults.reserve(1);
65+
66+
MlpackModel::convertBurstOfFeaturesToMatrix(burstOfFeatures, testDataset);
67+
m_ab.Classify(testDataset, predictions, probaMatrix);
68+
for (unsigned i = 0; i < predictions.n_elem; ++i) {
69+
std::vector<double> probabilities(probaMatrix.col(i).begin(), probaMatrix.col(i).end());
70+
burstResults.emplace_back(probabilities);
71+
}
72+
73+
return burstResults[0];
74+
}
75+
76+
/**
77+
* @brief Classify a burst of flow features
78+
*
79+
* @param burstOfFlowsFeatures the burst of flow features to classify
80+
* @return std::vector<ClfResult> the results of the classification. Each ClfResult contains
81+
* result of the classification, which contains vector<double> with probabilities for each
82+
* class
83+
*/
84+
std::vector<ClfResult> classify(const std::vector<FlowFeatures>& burstOfFeatures) override
85+
{
86+
arma::mat testDataset(m_featureIDs.size(), burstOfFeatures.size());
87+
arma::Row<size_t> predictions;
88+
arma::mat probaMatrix;
89+
std::vector<ClfResult> burstResults;
90+
91+
burstResults.reserve(burstOfFeatures.size());
92+
93+
MlpackModel::convertBurstOfFeaturesToMatrix(burstOfFeatures, testDataset);
94+
m_ab.Classify(testDataset, predictions, probaMatrix);
95+
for (unsigned i = 0; i < predictions.n_elem; ++i) {
96+
std::vector<double> probabilities(probaMatrix.col(i).begin(), probaMatrix.col(i).end());
97+
burstResults.emplace_back(probabilities);
98+
}
99+
100+
return burstResults;
101+
}
102+
103+
/**
104+
* @brief Load AdaBoost model from file
105+
*
106+
* @param modelPath contains path to the model file.
107+
* @param logicalName contains the logical name of the trained model.
108+
* @return Bool value true, if model was successfully loaded. False if not.
109+
*/
110+
bool
111+
load(const std::string& modelPath, const std::string& logicalName = "trained_data") override
112+
{
113+
m_loaded = mlpack::data::Load(modelPath, logicalName, m_ab);
114+
if (m_loaded) {
115+
m_modelPath = modelPath;
116+
}
117+
return m_loaded;
118+
}
119+
120+
/**
121+
* @brief Save AdaBoost model to file
122+
*
123+
* @param modelPath contains file path, where the model will be saved.
124+
* @param logicalName contains the logical name of the trained model.
125+
* @return Bool value true, if model was successfully saved. False if not.
126+
*/
127+
bool save(const std::string& modelPath, const std::string& logicalName = "trained_data")
128+
const override
129+
{
130+
return mlpack::data::Save(modelPath, logicalName, m_ab);
131+
}
132+
133+
/**
134+
* @brief Train AdaBoost model
135+
*
136+
* @param data contains training vector of flow features.
137+
* @param labels contains training labels, between 0 and numClasses - 1 (inclusive). Should have
138+
* length data.length().
139+
* @param path contains path, where file will be saved.
140+
* @param numClasses contains number of classes in the dataset.
141+
* @param maxIterations contains maximum number of iterations of AdaBoost.MH to use. This is the
142+
* maximum number of weak learners to train. (0 means no limit, and weak learners will be
143+
* trained until the tolerance is met.)
144+
* @param tolerance when the weighted residual (r_t) of the model goes below tolerance, training
145+
* will terminate and no more weak learners will be added.
146+
* @param weakLearnerParams optional weak learner hyperparameters.
147+
*/
148+
template<typename... WeakLearnerParams>
149+
void train(
150+
const std::vector<FlowFeatures>& data,
151+
const std::vector<size_t>& labels,
152+
const std::string& path,
153+
size_t numClasses = 2,
154+
size_t maxIterations = 100,
155+
double tolerance = 1e-6,
156+
WeakLearnerParams&&... weakLearnerParams)
157+
{
158+
arma::mat dataset(m_featureIDs.size(), data.size());
159+
arma::Row<size_t> armaLabels(labels);
160+
161+
MlpackModel::convertBurstOfFeaturesToMatrix(data, dataset);
162+
m_ab.Train(
163+
dataset,
164+
armaLabels,
165+
numClasses,
166+
maxIterations,
167+
tolerance,
168+
std::forward<WeakLearnerParams>(weakLearnerParams)...);
169+
170+
this->save(path);
171+
}
172+
173+
private:
174+
/**
175+
* @brief AdaBoost model
176+
*/
177+
mlpack::AdaBoost<WeakLearnerType, arma::mat> m_ab;
178+
};
179+
180+
} // namespace WIF::MlpackModels
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/**
2+
* @file
3+
* @author Jachym Hudlicky <[email protected]>
4+
* @brief Mlpack Decision tree model interface
5+
*
6+
* SPDX-License-Identifier: BSD-3-Clause
7+
*/
8+
9+
#pragma once
10+
11+
#include "wif/ml/mlpackModels/mlpackModel.hpp"
12+
#include "wif/storage/clfResult.hpp"
13+
#include "wif/storage/flowFeatures.hpp"
14+
15+
#include <armadillo>
16+
#include <memory>
17+
#include <mlpack/core.hpp>
18+
#include <mlpack/methods/decision_tree/decision_tree.hpp>
19+
#include <stdexcept>
20+
#include <string>
21+
#include <utility>
22+
#include <vector>
23+
24+
namespace WIF::MlpackModels {
25+
26+
/**
27+
* @brief Class which provides the Decision tree model from the mlpack library
28+
*/
29+
class DecisionTreeModel : public MlpackModel {
30+
public:
31+
/**
32+
* @brief Construct a new Decision tree object with no loaded model
33+
*/
34+
DecisionTreeModel() = default;
35+
36+
/**
37+
* @brief Construct a new Decision tree object from a trained model file
38+
*
39+
* @param modelPath path to the trained model file.
40+
* @param logicalName logical name of the trained model (defaults to "trained_data").
41+
*/
42+
DecisionTreeModel(
43+
const std::string& modelPath,
44+
const std::string& logicalName = "trained_data");
45+
46+
/**
47+
* @brief Classify a single FlowFeatures object
48+
*
49+
* @param flowFeatures flow features to classify
50+
* @return ClfResult result of the classification, which contains vector<double> with
51+
* probabilities for each class.
52+
*/
53+
ClfResult classify(const FlowFeatures& flowFeatures) override;
54+
55+
/**
56+
* @brief Classify a burst of flow features
57+
*
58+
* @param burstOfFeatures the burst of flow features to classify
59+
* @return std::vector<ClfResult> the results of the classification. Each ClfResult contains
60+
* the result of the classification, which contains vector<double> with probabilities for each
61+
* class.
62+
*/
63+
std::vector<ClfResult> classify(const std::vector<FlowFeatures>& burstOfFeatures) override;
64+
65+
/**
66+
* @brief Load Decision tree model from file
67+
*
68+
* @param modelPath path to the model file.
69+
* @param logicalName logical name of the trained model (defaults to "trained_data").
70+
* @return true if the model was successfully loaded, false otherwise.
71+
*/
72+
bool
73+
load(const std::string& modelPath, const std::string& logicalName = "trained_data") override;
74+
75+
/**
76+
* @brief Save Decision tree model to file
77+
*
78+
* @param modelPath path where the model will be saved.
79+
* @param logicalName logical name of the trained model (defaults to "trained_data").
80+
* @return true if the model was successfully saved, false otherwise.
81+
*/
82+
bool save(const std::string& modelPath, const std::string& logicalName = "trained_data")
83+
const override;
84+
85+
/**
86+
* @brief Train Decision tree model
87+
*
88+
* @param data training vector of flow features.
89+
* @param labels training labels, between 0 and numClasses - 1 (inclusive). Should have
90+
* the same number of elements as data (data.size()).
91+
* @param path path where the file will be saved.
92+
* @param numClasses number of classes in the dataset (defaults to 2).
93+
* @param minLeafSize minimum number of points in each leaf node (defaults to 10).
94+
* @param minGainSplit minimum gain for a node to split (defaults to 1e-7).
95+
* @param maxDepth maximum depth for the tree. (0 means no limit.)
96+
*/
97+
void train(
98+
const std::vector<FlowFeatures>& data,
99+
const std::vector<size_t>& labels,
100+
const std::string& path,
101+
size_t numClasses = 2,
102+
size_t minLeafSize = 10,
103+
double minGainSplit = 1e-7,
104+
size_t maxDepth = 0);
105+
106+
private:
107+
/**
108+
* @brief Decision tree model
109+
*/
110+
mlpack::DecisionTree<> m_dt;
111+
};
112+
113+
} // namespace WIF::MlpackModels

0 commit comments

Comments
 (0)