Skip to content

Commit ddffbf2

Browse files
authored
Merge pull request #45232 from mmusich/mm_fix_stream_PhotonXGBoostProducer
introduce const variant of `XGBooster::predict `and use it in `PhotonXGBoostEstimator`
2 parents 8cd5535 + 0cb0985 commit ddffbf2

File tree

5 files changed

+29
-26
lines changed

5 files changed

+29
-26
lines changed

PhysicsTools/XGBoost/interface/XGBooster.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace pat {
2323
void set(std::string name, float value);
2424

2525
float predict(const int iterationEnd = 0);
26+
float predict(const std::vector<float>& features, const int iterationEnd = 0) const;
2627

2728
private:
2829
std::vector<float> features_;

PhysicsTools/XGBoost/src/XGBooster.cc

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@ void XGBooster::addFeature(std::string name) {
7979
void XGBooster::set(std::string name, float value) { features_.at(feature_name_to_index_[name]) = value; }
8080

8181
float XGBooster::predict(const int iterationEnd) {
82-
float result(-999.);
83-
8482
// check if all feature values are set properly
8583
for (unsigned int i = 0; i < features_.size(); ++i)
8684
if (std::isnan(features_.at(i))) {
@@ -94,8 +92,25 @@ float XGBooster::predict(const int iterationEnd) {
9492
throw std::runtime_error("Feature is not set: " + feature_name);
9593
}
9694

95+
float const ret = predict(features_, iterationEnd);
96+
97+
reset();
98+
99+
return ret;
100+
}
101+
102+
float XGBooster::predict(const std::vector<float>& features, const int iterationEnd) const {
103+
float result{-999.};
104+
105+
if (features.empty()) {
106+
throw std::runtime_error("Vector of input features is empty");
107+
}
108+
109+
if (feature_name_to_index_.size() != features.size())
110+
throw std::runtime_error("Feature size mismatch");
111+
97112
DMatrixHandle dvalues;
98-
XGDMatrixCreateFromMat(&features_[0], 1, features_.size(), 9e99, &dvalues);
113+
XGDMatrixCreateFromMat(&features[0], 1, features.size(), 9e99, &dvalues);
99114

100115
bst_ulong out_len = 0;
101116
const float* score = nullptr;
@@ -126,7 +141,5 @@ float XGBooster::predict(const int iterationEnd) {
126141

127142
XGDMatrixFree(dvalues);
128143

129-
reset();
130-
131144
return result;
132145
}

RecoEgamma/PhotonIdentification/interface/PhotonXGBoostEstimator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
#include "FWCore/ParameterSet/interface/FileInPath.h"
55
#include "PhysicsTools/XGBoost/interface/XGBooster.h"
66

7+
#include <memory>
8+
79
class PhotonXGBoostEstimator {
810
public:
911
PhotonXGBoostEstimator(const edm::FileInPath& weightsFile, int best_ntree_limit);
10-
~PhotonXGBoostEstimator();
1112

1213
float computeMva(float rawEnergyIn,
1314
float r9In,
@@ -22,7 +23,6 @@ class PhotonXGBoostEstimator {
2223
private:
2324
std::unique_ptr<pat::XGBooster> booster_;
2425
int best_ntree_limit_ = -1;
25-
std::string config_;
2626
};
2727

2828
#endif

RecoEgamma/PhotonIdentification/plugins/PhotonXGBoostProducer.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ class PhotonXGBoostProducer : public edm::global::EDProducer<> {
3838
const unsigned mvaNTreeLimitB_;
3939
const unsigned mvaNTreeLimitE_;
4040
const double mvaThresholdEt_;
41-
std::unique_ptr<PhotonXGBoostEstimator> mvaEstimatorB_;
42-
std::unique_ptr<PhotonXGBoostEstimator> mvaEstimatorE_;
41+
const std::unique_ptr<const PhotonXGBoostEstimator> mvaEstimatorB_;
42+
const std::unique_ptr<const PhotonXGBoostEstimator> mvaEstimatorE_;
4343
};
4444

4545
PhotonXGBoostProducer::PhotonXGBoostProducer(edm::ParameterSet const& config)
@@ -54,9 +54,9 @@ PhotonXGBoostProducer::PhotonXGBoostProducer(edm::ParameterSet const& config)
5454
mvaFileXgbE_(config.getParameter<edm::FileInPath>("mvaFileXgbE")),
5555
mvaNTreeLimitB_(config.getParameter<unsigned int>("mvaNTreeLimitB")),
5656
mvaNTreeLimitE_(config.getParameter<unsigned int>("mvaNTreeLimitE")),
57-
mvaThresholdEt_(config.getParameter<double>("mvaThresholdEt")) {
58-
mvaEstimatorB_ = std::make_unique<PhotonXGBoostEstimator>(mvaFileXgbB_, mvaNTreeLimitB_);
59-
mvaEstimatorE_ = std::make_unique<PhotonXGBoostEstimator>(mvaFileXgbE_, mvaNTreeLimitE_);
57+
mvaThresholdEt_(config.getParameter<double>("mvaThresholdEt")),
58+
mvaEstimatorB_{std::make_unique<const PhotonXGBoostEstimator>(mvaFileXgbB_, mvaNTreeLimitB_)},
59+
mvaEstimatorE_{std::make_unique<const PhotonXGBoostEstimator>(mvaFileXgbE_, mvaNTreeLimitE_)} {
6060
produces<reco::RecoEcalCandidateIsolationMap>();
6161
}
6262

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "RecoEgamma/PhotonIdentification/interface/PhotonXGBoostEstimator.h"
2-
#include <sstream>
32

43
PhotonXGBoostEstimator::PhotonXGBoostEstimator(const edm::FileInPath& weightsFile, int best_ntree_limit) {
54
booster_ = std::make_unique<pat::XGBooster>(weightsFile.fullPath());
@@ -16,8 +15,6 @@ PhotonXGBoostEstimator::PhotonXGBoostEstimator(const edm::FileInPath& weightsFil
1615
best_ntree_limit_ = best_ntree_limit;
1716
}
1817

19-
PhotonXGBoostEstimator::~PhotonXGBoostEstimator() {}
20-
2118
float PhotonXGBoostEstimator::computeMva(float rawEnergyIn,
2219
float r9In,
2320
float sigmaIEtaIEtaIn,
@@ -27,15 +24,7 @@ float PhotonXGBoostEstimator::computeMva(float rawEnergyIn,
2724
float etaIn,
2825
float hOvrEIn,
2926
float ecalPFIsoIn) const {
30-
booster_->set("rawEnergy", rawEnergyIn);
31-
booster_->set("r9", r9In);
32-
booster_->set("sigmaIEtaIEta", sigmaIEtaIEtaIn);
33-
booster_->set("etaWidth", etaWidthIn);
34-
booster_->set("phiWidth", phiWidthIn);
35-
booster_->set("s4", s4In);
36-
booster_->set("eta", etaIn);
37-
booster_->set("hOvrE", hOvrEIn);
38-
booster_->set("ecalPFIso", ecalPFIsoIn);
39-
40-
return booster_->predict(best_ntree_limit_);
27+
return booster_->predict(
28+
{rawEnergyIn, r9In, sigmaIEtaIEtaIn, etaWidthIn, phiWidthIn, s4In, etaIn, hOvrEIn, ecalPFIsoIn},
29+
best_ntree_limit_);
4130
}

0 commit comments

Comments
 (0)