Skip to content

Commit 0d333c4

Browse files
authored
[PWGJE] Adding the IPz as a feature for ML (AliceO2Group#11088)
1 parent 3151e24 commit 0d333c4

File tree

3 files changed

+39
-6
lines changed

3 files changed

+39
-6
lines changed

PWGJE/Core/JetTaggingUtilities.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ struct BJetTrackParams {
7979
double deltaRJetTrack = 0.0;
8080
double signedIP2D = 0.0;
8181
double signedIP2DSign = 0.0;
82+
double signedIPz = 0.0;
83+
double signedIPzSign = 0.0;
8284
double signedIP3D = 0.0;
8385
double signedIP3DSign = 0.0;
8486
double momFraction = 0.0;
@@ -1011,7 +1013,7 @@ void analyzeJetTrackInfo4ML(AnalysisJet const& analysisJet, AnyTracks const& /*a
10111013
}
10121014
}
10131015

1014-
tracksParams.emplace_back(BJetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), rClosestSV});
1016+
tracksParams.emplace_back(BJetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaZ()) * sign, constituent.sigmadcaZ(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), rClosestSV});
10151017
}
10161018

10171019
auto compare = [](BJetTrackParams& tr1, BJetTrackParams& tr2) {
@@ -1036,7 +1038,7 @@ void analyzeJetTrackInfo4MLnoSV(AnalysisJet const& analysisJet, AnyTracks const&
10361038
double dotProduct = RecoDecay::dotProd(std::array<float, 3>{analysisJet.px(), analysisJet.py(), analysisJet.pz()}, std::array<float, 3>{constituent.px(), constituent.py(), constituent.pz()});
10371039
int sign = getGeoSign(analysisJet, constituent);
10381040

1039-
tracksParams.emplace_back(BJetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), 0.0});
1041+
tracksParams.emplace_back(BJetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaZ()) * sign, constituent.sigmadcaZ(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), 0.0});
10401042
}
10411043

10421044
auto compare = [](BJetTrackParams& tr1, BJetTrackParams& tr2) {

PWGJE/Core/MlResponseHfTagging.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ enum class InputFeaturesBTag : uint8_t {
7171
deltaRJetTrack,
7272
signedIP2D,
7373
signedIP2DSign,
74+
signedIPz,
75+
signedIPzSign,
7476
signedIP3D,
7577
signedIP3DSign,
7678
momFraction,
@@ -148,6 +150,8 @@ class MlResponseHfTagging : public MlResponse<TypeOutputScore>
148150
CHECK_AND_FILL_VEC_BTAG(trackInput, track, deltaRJetTrack)
149151
CHECK_AND_FILL_VEC_BTAG(trackInput, track, signedIP2D)
150152
CHECK_AND_FILL_VEC_BTAG(trackInput, track, signedIP2DSign)
153+
CHECK_AND_FILL_VEC_BTAG(trackInput, track, signedIPz)
154+
CHECK_AND_FILL_VEC_BTAG(trackInput, track, signedIPzSign)
151155
CHECK_AND_FILL_VEC_BTAG(trackInput, track, signedIP3D)
152156
CHECK_AND_FILL_VEC_BTAG(trackInput, track, signedIP3DSign)
153157
CHECK_AND_FILL_VEC_BTAG(trackInput, track, momFraction)
@@ -192,6 +196,23 @@ class MlResponseHfTagging : public MlResponse<TypeOutputScore>
192196
}
193197
}
194198

199+
/// @brief Method to replace NaN and infinity values in a vector with a specified value
200+
/// @param vec is the vector to be processed
201+
/// @param value is the value to replace NaN values with
202+
/// @return the number of NaN values replaced
203+
template <typename T>
204+
static int replaceNaN(std::vector<T>& vec, T value)
205+
{
206+
int numNaN = 0;
207+
for (auto& el : vec) {
208+
if (std::isnan(el) || std::isinf(el)) {
209+
el = value;
210+
++numNaN;
211+
}
212+
}
213+
return numNaN;
214+
}
215+
195216
/// Method to get the input features vector needed for ML inference in a 2D vector
196217
/// \param jet is the b-jet candidate
197218
/// \param tracks is the vector of tracks associated to the jet
@@ -209,6 +230,10 @@ class MlResponseHfTagging : public MlResponse<TypeOutputScore>
209230

210231
std::vector<std::vector<float>> inputFeatures;
211232

233+
replaceNaN(jetInput, 0.f);
234+
replaceNaN(trackInput, 0.f);
235+
replaceNaN(svInput, 0.f);
236+
212237
inputFeatures.push_back(jetInput);
213238
inputFeatures.push_back(trackInput);
214239
inputFeatures.push_back(svInput);
@@ -237,6 +262,8 @@ class MlResponseHfTagging : public MlResponse<TypeOutputScore>
237262
inputFeatures.insert(inputFeatures.end(), trackInput.begin(), trackInput.end());
238263
inputFeatures.insert(inputFeatures.end(), svInput.begin(), svInput.end());
239264

265+
replaceNaN(inputFeatures, 0.f);
266+
240267
return inputFeatures;
241268
}
242269

@@ -261,6 +288,8 @@ class MlResponseHfTagging : public MlResponse<TypeOutputScore>
261288
FILL_MAP_BJET(deltaRJetTrack),
262289
FILL_MAP_BJET(signedIP2D),
263290
FILL_MAP_BJET(signedIP2DSign),
291+
FILL_MAP_BJET(signedIPz),
292+
FILL_MAP_BJET(signedIPzSign),
264293
FILL_MAP_BJET(signedIP3D),
265294
FILL_MAP_BJET(signedIP3DSign),
266295
FILL_MAP_BJET(momFraction),

PWGJE/Tasks/bjetTreeCreator.cxx

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ DECLARE_SOA_COLUMN(DotProdTrackJetOverJet, trackdotjetoverjet, float); //! The d
8282
DECLARE_SOA_COLUMN(DeltaRJetTrack, rjettrack, float); //! The DR jet-track
8383
DECLARE_SOA_COLUMN(SignedIP2D, ip2d, float); //! The track signed 2D IP
8484
DECLARE_SOA_COLUMN(SignedIP2DSign, ip2dsigma, float); //! The track signed 2D IP significance
85-
DECLARE_SOA_COLUMN(SignedIP3D, ip3d, float); //! The track signed 3D IP
85+
DECLARE_SOA_COLUMN(SignedIPz, ipz, float); //! The track signed z IP
86+
DECLARE_SOA_COLUMN(SignedIPzSign, ipzsigma, float); //! The track signed z IP significance
8687
DECLARE_SOA_COLUMN(SignedIP3DSign, ip3dsigma, float); //! The track signed 3D IP significance
8788
DECLARE_SOA_COLUMN(MomFraction, momfraction, float); //! The track momentum fraction of the jets
8889
DECLARE_SOA_COLUMN(DeltaRTrackVertex, rtrackvertex, float); //! DR between the track and the closest SV, to be decided whether to add to or not
@@ -108,7 +109,8 @@ DECLARE_SOA_TABLE(bjetTracksParams, "AOD", "BJETTRACKSPARAM",
108109
trackInfo::DeltaRJetTrack,
109110
trackInfo::SignedIP2D,
110111
trackInfo::SignedIP2DSign,
111-
trackInfo::SignedIP3D,
112+
trackInfo::SignedIPz,
113+
trackInfo::SignedIPzSign,
112114
trackInfo::SignedIP3DSign,
113115
trackInfo::MomFraction,
114116
trackInfo::DeltaRTrackVertex);
@@ -460,7 +462,7 @@ struct BJetTreeCreator {
460462
}
461463

462464
if (produceTree) {
463-
bjetTracksParamsTable(bjetParamsTable.lastIndex() + 1, constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), RClosestSV);
465+
bjetTracksParamsTable(bjetParamsTable.lastIndex() + 1, constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaZ()) * sign, constituent.sigmadcaZ(), constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), RClosestSV);
464466
}
465467
trackIndices.push_back(bjetTracksParamsTable.lastIndex());
466468
}
@@ -531,7 +533,7 @@ struct BJetTreeCreator {
531533
}
532534

533535
if (produceTree) {
534-
bjetTracksParamsTable(bjetParamsTable.lastIndex() + 1, constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), 0.);
536+
bjetTracksParamsTable(bjetParamsTable.lastIndex() + 1, constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaZ()) * sign, constituent.sigmadcaZ(), constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), 0.);
535537
}
536538
trackIndices.push_back(bjetTracksParamsTable.lastIndex());
537539
}

0 commit comments

Comments
 (0)