Skip to content

Commit f0566c7

Browse files
authored
AlfClassifier - Introduce support MlpackClassifier (#9)
1 parent 66f5cf0 commit f0566c7

File tree

11 files changed

+66
-24
lines changed

11 files changed

+66
-24
lines changed

include/wif/classifiers/alfClassifier.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
#pragma once
1111

12-
#include "wif/classifiers/scikitMlClassifier.hpp"
12+
#include "wif/classifiers/genericMlClassifier.hpp"
1313
#include "wif/filesystem/fileModificationChecker.hpp"
1414
#include "wif/reporters/unirecReporter.hpp"
1515
#include "wif/utils/timer.hpp"
@@ -47,7 +47,7 @@ class AlfClassifier : public Classifier {
4747
* @param timerIntervalInSeconds number of seconds between each check of file on disk
4848
*/
4949
AlfClassifier(
50-
ScikitMlClassifier& mlClassifier,
50+
GenericMlClassifier& mlClassifier,
5151
UnirecReporter& reporter,
5252
unsigned timerIntervalInSeconds);
5353

@@ -85,7 +85,7 @@ class AlfClassifier : public Classifier {
8585

8686
bool m_modelReloadNeeded = false;
8787
uint64_t m_lastModelLoadTime;
88-
ScikitMlClassifier& m_mlClassifier;
88+
GenericMlClassifier& m_mlClassifier;
8989
UnirecReporter& m_reporter;
9090
Timer m_timer;
9191
};
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/**
2+
* @file
3+
* @author Jachym Hudlicky <[email protected]>
4+
* @brief Generic machine learning classifier interface
5+
*
6+
* SPDX-License-Identifier: BSD-3-Clause
7+
*/
8+
9+
#pragma once
10+
11+
#include "wif/classifiers/classifier.hpp"
12+
13+
#include <memory>
14+
#include <vector>
15+
16+
namespace WIF {
17+
18+
/**
19+
* @brief Abstract class specifying interfaces for ML classifiers (ScikitMlClassifier and
20+
* MlpackClassifier)
21+
*
22+
*/
23+
class GenericMlClassifier : public Classifier {
24+
public:
25+
/**
26+
* @brief Return the path of the ML model, which is currently loaded
27+
* @return const std::string& path of the model
28+
*/
29+
virtual const std::string& getMlModelPath() const noexcept = 0;
30+
31+
/**
32+
* @brief Reload the model from file, which was set in the constructor
33+
*
34+
* @param logicalName contains the logical name of the trained model. The parameter is used only
35+
* with MlpackClassifier (it is unused with ScikitMlClassifier)
36+
*/
37+
virtual void reloadModelFromDisk(const std::string& logicalName = "trained_data") = 0;
38+
};
39+
40+
} // namespace WIF

include/wif/classifiers/mlpackClassifier.hpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#pragma once
1010

11-
#include "wif/classifiers/classifier.hpp"
11+
#include "wif/classifiers/genericMlClassifier.hpp"
1212
#include "wif/ml/mlpackWrapper.hpp"
1313

1414
#include <memory>
@@ -21,13 +21,13 @@ namespace WIF {
2121
* @brief Classifier performing ML classification which is interconnected with Mlpack library
2222
*
2323
*/
24-
class MlpackClassifier : public Classifier {
24+
class MlpackClassifier : public GenericMlClassifier {
2525
public:
2626
/**
2727
* @brief Construct a new Mlpack Classifier object
2828
*
29-
* @param path contains the path to the file with the trained model.
30-
* @param logicalName contains the logical name of the trained model.
29+
* @param path contains the path to the file with the trained model
30+
* @param logicalName contains the logical name of the trained model
3131
*/
3232
MlpackClassifier(const std::string& path, const std::string& logicalName = "trained_data");
3333

@@ -44,7 +44,7 @@ class MlpackClassifier : public Classifier {
4444
*
4545
* @param flowFeatures flow features to classify
4646
* @return ClfResult result of the classification, which contains double represention class or
47-
* vector<double> with probabilities for each class (depends on model).
47+
* vector<double> with probabilities for each class (depends on model)
4848
*/
4949
ClfResult classify(const FlowFeatures& flowFeatures) override;
5050

@@ -59,16 +59,16 @@ class MlpackClassifier : public Classifier {
5959

6060
/**
6161
* @brief Return the path of the ML model, which is currently loaded
62-
* @return const std::string& path of the model.
62+
* @return const std::string& path of the model
6363
*/
64-
const std::string getMlModelPath() const noexcept;
64+
const std::string& getMlModelPath() const noexcept override;
6565

6666
/**
6767
* @brief Reload the model from file, which was set in the constructor
6868
*
69-
* @param logicalName contains the logical name of the trained model.
69+
* @param logicalName contains the logical name of the trained model
7070
*/
71-
void reloadModelFromDisk(const std::string& logicalName = "trained_data");
71+
void reloadModelFromDisk(const std::string& logicalName = "trained_data") override;
7272

7373
private:
7474
/**

include/wif/classifiers/scikitMlClassifier.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#pragma once
1010

11-
#include "wif/classifiers/classifier.hpp"
11+
#include "wif/classifiers/genericMlClassifier.hpp"
1212
#include "wif/ml/scikitMlWrapper.hpp"
1313

1414
#include <memory>
@@ -19,7 +19,7 @@ namespace WIF {
1919
/**
2020
* @brief Classifier performing Machine-Learning based detection via Scikit-learn library
2121
*/
22-
class ScikitMlClassifier : public Classifier {
22+
class ScikitMlClassifier : public GenericMlClassifier {
2323
public:
2424
/**
2525
* @brief Construct a new Scikitlearn Ml Classifier object
@@ -34,7 +34,7 @@ class ScikitMlClassifier : public Classifier {
3434
*
3535
* @return const std::string& the path to the used ML model
3636
*/
37-
const std::string& getMlModelPath() const noexcept;
37+
const std::string& getMlModelPath() const noexcept override;
3838

3939
/**
4040
* @brief Set feature IDs which will be used for classification
@@ -45,8 +45,10 @@ class ScikitMlClassifier : public Classifier {
4545

4646
/**
4747
* @brief Reload used ML model from disk
48+
*
49+
* @param logicalName is unused
4850
*/
49-
void reloadModelFromDisk();
51+
void reloadModelFromDisk([[maybe_unused]] const std::string& logicalName) override;
5052

5153
/**
5254
* @brief Classify single flowFeature object

include/wif/ml/mlpackModels/mlpackModel.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class MlpackModel {
8585
* @brief Getter for path of the used ML model
8686
* @return const std::string&
8787
*/
88-
std::string getPath() const;
88+
const std::string& getPath() const;
8989

9090
/**
9191
* @brief Set feature IDs which will be used for classification

include/wif/ml/mlpackWrapper.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class MlpackWrapper {
6060
* @brief Getter for path of the used ML model
6161
* @return const std::string&
6262
*/
63-
const std::string getModelPath() const;
63+
const std::string& getModelPath() const;
6464

6565
/**
6666
* @brief Load the model from the file

src/wif/classifiers/alfClassifier.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ void AlfClassifier::AlfCallback::onTick()
2525
}
2626

2727
AlfClassifier::AlfClassifier(
28-
ScikitMlClassifier& mlClassifier,
28+
GenericMlClassifier& mlClassifier,
2929
UnirecReporter& reporter,
3030
unsigned timerIntervalInSeconds)
3131
: m_mlClassifier(mlClassifier)
3232
, m_reporter(reporter)
3333
, m_timer(
3434
timerIntervalInSeconds,
35-
std::make_unique<AlfCallback>(mlClassifier.getMlModelPath(), *this))
35+
std::make_unique<AlfCallback>(m_mlClassifier.getMlModelPath(), *this))
3636
{
3737
updateLastModelLoadTime();
3838
m_timer.start();

src/wif/classifiers/mlpackClassifier.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ MlpackClassifier::classify(const std::vector<FlowFeatures>& burstOfFlowFeatures)
3232
return m_mlpackWrapper->classify(burstOfFlowFeatures);
3333
}
3434

35-
const std::string MlpackClassifier::getMlModelPath() const noexcept
35+
const std::string& MlpackClassifier::getMlModelPath() const noexcept
3636
{
3737
return m_mlpackWrapper->getModelPath();
3838
}

src/wif/classifiers/scikitMlClassifier.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void ScikitMlClassifier::setFeatureSourceIDs(const std::vector<FeatureID>& sourc
2828
m_scikitMlWrapper->setFeatureSourceIDs(sourceFeatureIDs);
2929
}
3030

31-
void ScikitMlClassifier::reloadModelFromDisk()
31+
void ScikitMlClassifier::reloadModelFromDisk([[maybe_unused]] const std::string& logicalName)
3232
{
3333
m_scikitMlWrapper->reloadModel();
3434
}

src/wif/ml/mlpackModels/mlpackModel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ bool MlpackModel::isLoaded() const
1515
return m_loaded;
1616
}
1717

18-
std::string MlpackModel::getPath() const
18+
const std::string& MlpackModel::getPath() const
1919
{
2020
return m_modelPath;
2121
}

0 commit comments

Comments
 (0)