Skip to content

Commit 2fb2ea7

Browse files
committed
model v5
1 parent 4eeb0e8 commit 2fb2ea7

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

L1Trigger/L1TGlobal/interface/AXOL1TLCondition.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <iosfwd>
1111
#include <string>
12+
#include <utility>
1213

1314
#include "L1Trigger/L1TGlobal/interface/ConditionEvaluation.h"
1415
#include "DataFormats/L1Trigger/interface/L1Candidate.h"
@@ -24,6 +25,14 @@ namespace l1t {
2425
class L1Candidate;
2526
class GlobalBoard;
2627

28+
//template function for reading results
29+
template <typename ResultType, typename LossType>
30+
LossType readResult(hls4mlEmulator::Model& model) {
31+
std::pair<ResultType, LossType> ADModelResult; //model outputs a pair of the (result vector, loss)
32+
model.read_result(&ADModelResult);
33+
return ADModelResult.second;
34+
}
35+
2736
// class declaration
2837
class AXOL1TLCondition : public ConditionEvaluation {
2938
public:

L1Trigger/L1TGlobal/src/AXOL1TLCondition.cc

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,7 @@ const bool l1t::AXOL1TLCondition::evaluateCondition(const int bxEval) const {
130130

131131
//types of inputs and outputs
132132
typedef ap_fixed<18, 13> inputtype;
133-
typedef std::array<ap_fixed<10, 7, AP_RND_CONV, AP_SAT>, 8> resulttype; //v3
134133
typedef ap_ufixed<18, 14> losstype;
135-
typedef std::pair<resulttype, losstype> pairtype;
136-
// typedef std::array<ap_fixed<10, 7>, 13> resulttype; //deprecated v1 type:
137134

138135
//define zero
139136
inputtype fillzero = 0.0;
@@ -148,10 +145,10 @@ const bool l1t::AXOL1TLCondition::evaluateCondition(const int bxEval) const {
148145
inputtype EtSumInput[EtSumVecSize];
149146

150147
//declare result vectors +score
151-
resulttype result;
148+
// resulttype result;
152149
losstype loss;
153-
pairtype ADModelResult; //model outputs a pair of the (result vector, loss)
154-
float score = -1.0; //not sure what the best default is hm??
150+
// pairtype ADModelResult; //model outputs a pair of the (result vector, loss)
151+
float score = -1.0; //not sure what the best default is hm??
155152

156153
//check number of input objects we actually have (muons, jets etc)
157154
int NCandMu = candMuVec->size(useBx);
@@ -198,8 +195,8 @@ const bool l1t::AXOL1TLCondition::evaluateCondition(const int bxEval) const {
198195
if (iMu < NMuons) { //stop if fill the Nobjects we need
199196
MuInput[0 + (3 * iMu)] = ((candMuVec->at(useBx, iMu))->hwPt()) /
200197
2; //index 0,3,6,9 //have to do hwPt/2 in order to match original et inputs
201-
MuInput[1 + (3 * iMu)] = (candMuVec->at(useBx, iMu))->hwEta(); //index 1,4,7,10
202-
MuInput[2 + (3 * iMu)] = (candMuVec->at(useBx, iMu))->hwPhi(); //index 2,5,8,11
198+
MuInput[1 + (3 * iMu)] = (candMuVec->at(useBx, iMu))->hwEtaAtVtx(); //index 1,4,7,10
199+
MuInput[2 + (3 * iMu)] = (candMuVec->at(useBx, iMu))->hwPhiAtVtx(); //index 2,5,8,11
203200
}
204201
}
205202
}
@@ -234,10 +231,18 @@ const bool l1t::AXOL1TLCondition::evaluateCondition(const int bxEval) const {
234231
//now run the inference
235232
m_model->prepare_input(ADModelInput); //scaling internal here
236233
m_model->predict();
237-
m_model->read_result(&ADModelResult); // this should be the square sum model result
234+
// m_model->read_result(&ADModelResult); // this should be the square sum model result
235+
if ((m_model_loader.model_name() == "GTADModel_v3") ||
236+
(m_model_loader.model_name() == "GTADModel_v4")) { //v3/v4 overwrite
237+
using resulttype = std::array<ap_fixed<10, 7, AP_RND_CONV, AP_SAT>, 8>;
238+
loss = readResult<resulttype, losstype>(*m_model);
239+
} else { //v5 default
240+
using resulttype = ap_fixed<18, 14, AP_RND_CONV, AP_SAT>;
241+
loss = readResult<resulttype, losstype>(*m_model);
242+
}
238243

239-
result = ADModelResult.first;
240-
loss = ADModelResult.second;
244+
// result = ADModelResult.first;
245+
// loss = ADModelResult.second;
241246
score = ((loss).to_float()) * 16.0; //scaling to match threshold
242247
//save score to class variable in case score saving needed
243248
setScore(score);

0 commit comments

Comments
 (0)