Skip to content

Commit e52e5fa

Browse files
committed
Merge branch 'tg-45-alf-classifier' into 'main'
WIF - ALF Classifier See merge request feta/wif-group/libwif!36
2 parents 414550a + a931ca8 commit e52e5fa

File tree

7 files changed

+226
-2
lines changed

7 files changed

+226
-2
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/**
2+
* @file
3+
* @author Pavel Siska <[email protected]>
4+
* @author Richard Plny <[email protected]>
5+
* @brief ALF classifier interface
6+
*
7+
* SPDX-License-Identifier: BSD-3-Clause
8+
*/
9+
10+
#pragma once
11+
12+
#include "wif/classifiers/scikitMlClassifier.hpp"
#include "wif/filesystem/fileModificationChecker.hpp"
#include "wif/reporters/unirecReporter.hpp"
#include "wif/utils/timer.hpp"

#include <atomic>
#include <chrono>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
20+
21+
namespace WIF {
22+
23+
/**
24+
* @brief Classifier performing ML classification which is interconnected with ALF
25+
*
26+
*/
27+
class AlfClassifier : public Classifier {
28+
private:
29+
class AlfCallback : public TimerCallback {
30+
public:
31+
AlfCallback(const std::string& modelPath, AlfClassifier& classifier);
32+
33+
void onTick() override;
34+
35+
private:
36+
FileModificationChecker m_fileModificationChecker;
37+
AlfClassifier& m_classifier;
38+
};
39+
40+
public:
41+
/**
42+
* @brief Construct a new Alf Classifier object
43+
*
44+
* @param mlClassifier underlying ML classifier, which has already set source feature IDs
45+
* @param reporter UniRec reporter for ALF interconnection, which already has its interface
46+
* ready
47+
* @param timerIntervalInSeconds number of seconds between each check of file on disk
48+
*/
49+
AlfClassifier(
50+
ScikitMlClassifier& mlClassifier,
51+
UnirecReporter& reporter,
52+
unsigned timerIntervalInSeconds);
53+
54+
/**
55+
* @brief Classify single flowFeature object
56+
* See std::vector<ClfResult> classify(const std::vector<FlowFeatures>&) for more details
57+
*
58+
* @param flowFeatures flow features to classify
59+
* @return ClfResult result of the classification
60+
*/
61+
ClfResult classify(const FlowFeatures& flowFeatures) override;
62+
63+
/**
64+
* @brief Classify a burst of flow features
65+
* Firstly, ML model is reloaded if needed, then ML classifier is used to obtain results
66+
* Then, source features of ALF classifier are reported, then source features of ML classifier
67+
* are reported, then last model load time (uint64_t) and finally classification results are
68+
* reported
69+
*
70+
* @param burstOfFlowsFeatures the burst of flow features to classify
71+
* @return std::vector<ClfResult> the results of the classification
72+
*/
73+
std::vector<ClfResult> classify(const std::vector<FlowFeatures>& burstOfFlowsFeatures) override;
74+
75+
/**
76+
* @brief Method which marks used ML model as obsolete
77+
* This operation will result into ML model reload during the next call of classify() method
78+
*/
79+
void markModelAsObsolete() { m_modelReloadNeeded = true; }
80+
81+
private:
82+
void updateLastModelLoadTime();
83+
void handleModelUpdate();
84+
void handleSingleReport(const FlowFeatures flowFeatures, const ClfResult& result);
85+
86+
bool m_modelReloadNeeded = false;
87+
uint64_t m_lastModelLoadTime;
88+
ScikitMlClassifier& m_mlClassifier;
89+
UnirecReporter& m_reporter;
90+
Timer m_timer;
91+
};
92+
93+
} // namespace WIF

include/wif/classifiers/scikitMlClassifier.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,25 @@ class ScikitMlClassifier : public Classifier {
2929
*/
3030
ScikitMlClassifier(const std::string& bridgePath, const std::string& mlModelPath);
3131

32+
/**
33+
* @brief Get the path to the used ML model
34+
*
35+
* @return const std::string& the path to the used ML model
36+
*/
37+
const std::string& getMlModelPath() const noexcept;
38+
3239
/**
3340
* @brief Set feature IDs which will be used for classification
3441
*
3542
* @param sourceFeatureIDs
3643
*/
3744
void setFeatureSourceIDs(const std::vector<FeatureID>& sourceFeatureIDs) override;
3845

46+
/**
47+
* @brief Reload used ML model from disk
48+
*/
49+
void reloadModelFromDisk();
50+
3951
/**
4052
* @brief Classify single flowFeature object
4153
*

include/wif/ml/scikitMlWrapper.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ class ScikitMlWrapper {
5757
void setFeatureSourceIDs(const std::vector<FeatureID>& sourceFeatureIDs);
5858

5959
/**
60-
* @brief Getter for used ML model path
60+
* @brief Getter for path of the used ML model
6161
* @return const std::string&
6262
*/
63-
const std::string& mlModelPath() const noexcept { return m_mlModelPath; }
63+
const std::string& getMlModelPath() const noexcept { return m_mlModelPath; }
6464

6565
/**
6666
* @brief Reload model from disk

include/wif/storage/flowFeatures.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ class FlowFeatures {
4545
return std::get<T>(m_features.at(featureID));
4646
}
4747

48+
/**
49+
* @brief Get raw feature
50+
*
51+
* @param featureID the feature identifier
52+
* @return const DataVariant& raw feature
53+
*/
54+
const DataVariant& getRaw(FeatureID featureID) const { return m_features.at(featureID); }
55+
4856
/**
4957
* @brief Set feature
5058
*

src/wif/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ set(LIBWIF_LIBS
2121

2222
if(BUILD_WITH_UNIREC)
2323
list(APPEND LIBWIF_SOURCES
24+
classifiers/alfClassifier.cpp
2425
reporters/unirecReporter.cpp
2526
)
2627
list(APPEND LIBWIF_LIBS
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/**
2+
* @file
3+
* @author Pavel Siska <[email protected]>
4+
* @author Richard Plny <[email protected]>
5+
* @brief ALF classifier implementation
6+
*
7+
* SPDX-License-Identifier: BSD-3-Clause
8+
*/
9+
10+
#include "wif/classifiers/alfClassifier.hpp"
11+
12+
namespace WIF {
13+
14+
/**
 * @brief Bind the callback to the watched model file and to the classifier it notifies
 *
 * @param modelPath path handed over to the file modification checker
 * @param classifier classifier whose model is marked as obsolete from onTick()
 */
AlfClassifier::AlfCallback::AlfCallback(const std::string& modelPath, AlfClassifier& classifier)
    : m_fileModificationChecker(modelPath)
    , m_classifier(classifier)
{
    // Nothing else to set up; all state lives in the two members above.
}
19+
20+
void AlfClassifier::AlfCallback::onTick()
21+
{
22+
if (m_fileModificationChecker.isChangeDetected()) {
23+
m_classifier.markModelAsObsolete();
24+
}
25+
}
26+
27+
/**
 * @brief Wire the classifier to its ML backend and reporter, then start the periodic check
 *
 * @param mlClassifier underlying ML classifier; its model path is the file being watched
 * @param reporter reporter used for every classified flow
 * @param timerIntervalInSeconds interval passed to the timer
 */
AlfClassifier::AlfClassifier(
    ScikitMlClassifier& mlClassifier,
    UnirecReporter& reporter,
    unsigned timerIntervalInSeconds)
    : m_mlClassifier(mlClassifier)
    , m_reporter(reporter)
    , m_timer(
          timerIntervalInSeconds,
          // The callback only stores the reference to *this here.
          std::make_unique<AlfCallback>(mlClassifier.getMlModelPath(), *this))
{
    // Record the initial load time before the timer is started.
    updateLastModelLoadTime();
    m_timer.start();
}
40+
41+
/**
 * @brief Classify a single flow by delegating to the burst overload
 *
 * @param flowFeatures flow features to classify
 * @return ClfResult first (and only expected) result of the one-element burst
 */
ClfResult AlfClassifier::classify(const FlowFeatures& flowFeatures)
{
    const std::vector<FlowFeatures> singleFlowBurst {flowFeatures};
    const auto singleFlowResults = classify(singleFlowBurst);
    return singleFlowResults[0];
}
45+
46+
std::vector<ClfResult>
47+
AlfClassifier::classify(const std::vector<FlowFeatures>& burstOfFlowsFeatures)
48+
{
49+
if (m_modelReloadNeeded) {
50+
handleModelUpdate();
51+
}
52+
53+
auto burstOfResults = m_mlClassifier.classify(burstOfFlowsFeatures);
54+
for (unsigned flowIdx = 0; flowIdx < burstOfResults.size(); ++flowIdx) {
55+
auto& flowFeatures = burstOfFlowsFeatures[flowIdx];
56+
auto& results = burstOfResults[flowIdx];
57+
handleSingleReport(flowFeatures, results);
58+
}
59+
60+
return burstOfResults;
61+
}
62+
63+
void AlfClassifier::updateLastModelLoadTime()
64+
{
65+
const auto timestamp = std::chrono::system_clock::now();
66+
m_lastModelLoadTime
67+
= std::chrono::duration_cast<std::chrono::seconds>(timestamp.time_since_epoch()).count();
68+
}
69+
70+
void AlfClassifier::handleModelUpdate()
71+
{
72+
m_mlClassifier.reloadModelFromDisk();
73+
m_modelReloadNeeded = false;
74+
updateLastModelLoadTime();
75+
}
76+
77+
void AlfClassifier::handleSingleReport(const FlowFeatures flowFeatures, const ClfResult& result)
78+
{
79+
m_reporter.onRecordStart();
80+
81+
// Identification fields are reported first
82+
for (const auto id : sourceFeatureIDs()) {
83+
m_reporter.report(flowFeatures.getRaw(id));
84+
}
85+
86+
// ML features are reported as second
87+
for (const auto id : m_mlClassifier.sourceFeatureIDs()) {
88+
m_reporter.report(flowFeatures.getRaw(id));
89+
}
90+
91+
// Report last time ML model was reloaded
92+
m_reporter.report({m_lastModelLoadTime});
93+
94+
// ML proba array is reported as the last one
95+
m_reporter.report(result.get<std::vector<double>>());
96+
97+
m_reporter.onRecordEnd();
98+
}
99+
100+
} // namespace WIF

src/wif/classifiers/scikitMlClassifier.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,22 @@ ScikitMlClassifier::ScikitMlClassifier(
1717
m_scikitMlWrapper = std::make_unique<ScikitMlWrapper>(bridgePath, mlModelPath);
1818
}
1919

20+
const std::string& ScikitMlClassifier::getMlModelPath() const noexcept
21+
{
22+
return m_scikitMlWrapper->getMlModelPath();
23+
}
24+
2025
void ScikitMlClassifier::setFeatureSourceIDs(const std::vector<FeatureID>& sourceFeatureIDs)
2126
{
2227
Classifier::setFeatureSourceIDs(sourceFeatureIDs);
2328
m_scikitMlWrapper->setFeatureSourceIDs(sourceFeatureIDs);
2429
}
2530

31+
void ScikitMlClassifier::reloadModelFromDisk()
32+
{
33+
m_scikitMlWrapper->reloadModel();
34+
}
35+
2636
ClfResult ScikitMlClassifier::classify(const FlowFeatures& flowFeatures)
2737
{
2838
return m_scikitMlWrapper->classify({flowFeatures})[0];

0 commit comments

Comments
 (0)