Skip to content

Commit fe94271

Browse files
committed
AXOL1TLCondition: move model initialization to ctor
1 parent 9dc135c commit fe94271

File tree

2 files changed

+38
-30
lines changed

2 files changed

+38
-30
lines changed

L1Trigger/L1TGlobal/interface/AXOL1TLCondition.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@
77
* Description: evaluation of a CondAXOL1TL condition.
88
*/
99

10-
// system include files
1110
#include <iosfwd>
1211
#include <string>
1312

14-
// user include files
15-
// base classes
1613
#include "L1Trigger/L1TGlobal/interface/ConditionEvaluation.h"
1714
#include "DataFormats/L1Trigger/interface/L1Candidate.h"
1815

16+
#include "hls4ml/emulator.h"
17+
1918
// forward declarations
2019
class GlobalCondition;
2120
class AXOL1TLTemplate;
@@ -64,6 +63,10 @@ namespace l1t {
6463

6564
inline float getScore() const { return m_savedscore; }
6665

66+
void loadModel();
67+
68+
inline hls4mlEmulator::ModelLoader const& model_loader() const { return m_model_loader; }
69+
6770
private:
6871
/// copy function for copy constructor and operator=
6972
void copy(const AXOL1TLCondition& cp);
@@ -74,6 +77,11 @@ namespace l1t {
7477
/// pointer to uGt GlobalBoard, to be able to get the trigger objects
7578
const GlobalBoard* m_gtGTB;
7679

80+
static constexpr char const* kModelNamePrefix = "GTADModel_";
81+
82+
hls4mlEmulator::ModelLoader m_model_loader;
83+
std::shared_ptr<hls4mlEmulator::Model> m_model;
84+
7785
///axo score for possible score saving
7886
mutable float m_savedscore;
7987
};

L1Trigger/L1TGlobal/src/AXOL1TLCondition.cc

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include <vector>
2020
#include <algorithm>
2121
#include "ap_fixed.h"
22-
#include "hls4ml/emulator.h"
2322

2423
// user include files
2524
// base classes
@@ -42,17 +41,16 @@
4241
#include "FWCore/MessageLogger/interface/MessageLogger.h"
4342
#include "FWCore/MessageLogger/interface/MessageDrop.h"
4443

45-
// constructors
46-
// default
47-
l1t::AXOL1TLCondition::AXOL1TLCondition() : ConditionEvaluation() {
48-
// empty
49-
}
44+
l1t::AXOL1TLCondition::AXOL1TLCondition()
45+
: ConditionEvaluation(), m_gtAXOL1TLTemplate{nullptr}, m_gtGTB{nullptr}, m_model{nullptr} {}
5046

51-
// from base template condition (from event setup usually)
5247
l1t::AXOL1TLCondition::AXOL1TLCondition(const GlobalCondition* axol1tlTemplate, const GlobalBoard* ptrGTB)
5348
: ConditionEvaluation(),
5449
m_gtAXOL1TLTemplate(static_cast<const AXOL1TLTemplate*>(axol1tlTemplate)),
55-
m_gtGTB(ptrGTB) {}
50+
m_gtGTB(ptrGTB),
51+
m_model_loader{kModelNamePrefix + m_gtAXOL1TLTemplate->modelVersion()} {
52+
loadModel();
53+
}
5654

5755
// copy constructor
5856
void l1t::AXOL1TLCondition::copy(const l1t::AXOL1TLCondition& cp) {
@@ -64,6 +62,9 @@ void l1t::AXOL1TLCondition::copy(const l1t::AXOL1TLCondition& cp) {
6462
m_combinationsInCond = cp.getCombinationsInCond();
6563

6664
m_verbosity = cp.m_verbosity;
65+
66+
m_model_loader.reset(cp.model_loader().model_name());
67+
loadModel();
6768
}
6869

6970
l1t::AXOL1TLCondition::AXOL1TLCondition(const l1t::AXOL1TLCondition& cp) : ConditionEvaluation() { copy(cp); }
@@ -88,26 +89,25 @@ void l1t::AXOL1TLCondition::setuGtB(const GlobalBoard* ptrGTB) { m_gtGTB = ptrGT
8889
/// set score for score saving
8990
void l1t::AXOL1TLCondition::setScore(const float scoreval) const { m_savedscore = scoreval; }
9091

91-
const bool l1t::AXOL1TLCondition::evaluateCondition(const int bxEval) const {
92-
bool condResult = false;
93-
int useBx = bxEval + m_gtAXOL1TLTemplate->condRelativeBx();
94-
95-
//HLS4ML stuff
96-
std::string AXOL1TLmodelversion = "GTADModel_" + m_gtAXOL1TLTemplate->modelVersion(); //loading from menu/template
97-
98-
//otherwise load model (if possible) and run inference
99-
hls4mlEmulator::ModelLoader loader(AXOL1TLmodelversion);
100-
std::shared_ptr<hls4mlEmulator::Model> model;
101-
92+
void l1t::AXOL1TLCondition::loadModel() {
10293
try {
103-
model = loader.load_model();
94+
m_model = m_model_loader.load_model();
10495
} catch (std::runtime_error& e) {
105-
// for stopping with exception if model version cannot be loaded
106-
throw cms::Exception("ModelError")
107-
<< " ERROR: failed to load AXOL1TL model version \"" << AXOL1TLmodelversion
108-
<< "\" that was specified in menu. Model version not found in cms-hls4ml externals.";
96+
throw cms::Exception("ModelError") << " ERROR: failed to load AXOL1TL model version \""
97+
<< m_model_loader.model_name()
98+
<< "\". Model version not found in cms-hls4ml externals.";
99+
}
100+
}
101+
102+
const bool l1t::AXOL1TLCondition::evaluateCondition(const int bxEval) const {
103+
if (m_model == nullptr) {
104+
throw cms::Exception("ModelError") << " ERROR: no model was loaded for AXOL1TL model version \""
105+
<< m_model_loader.model_name() << "\".";
109106
}
110107

108+
bool condResult = false;
109+
int useBx = bxEval + m_gtAXOL1TLTemplate->condRelativeBx();
110+
111111
// //pointers to objects
112112
const BXVector<const l1t::Muon*>* candMuVec = m_gtGTB->getCandL1Mu();
113113
const BXVector<const l1t::L1Candidate*>* candJetVec = m_gtGTB->getCandL1Jet();
@@ -232,9 +232,9 @@ const bool l1t::AXOL1TLCondition::evaluateCondition(const int bxEval) const {
232232
}
233233

234234
//now run the inference
235-
model->prepare_input(ADModelInput); //scaling internal here
236-
model->predict();
237-
model->read_result(&ADModelResult); // this should be the square sum model result
235+
m_model->prepare_input(ADModelInput); //scaling internal here
236+
m_model->predict();
237+
m_model->read_result(&ADModelResult); // this should be the square sum model result
238238

239239
result = ADModelResult.first;
240240
loss = ADModelResult.second;

0 commit comments

Comments
 (0)