Skip to content

Commit 38a8e65

Browse files
choich08365Changhwan Choi
andauthored
[PWGJE] Configurable GNN input feature transform function (AliceO2Group#14346)
Co-authored-by: Changhwan Choi <[email protected]>
1 parent 20efe06 commit 38a8e65

File tree

3 files changed

+19
-80
lines changed

3 files changed

+19
-80
lines changed

PWGJE/Core/MlResponseHfTagging.h

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <cmath>
2727
#include <cstddef>
2828
#include <cstdint>
29+
#include <string>
2930
#include <utility>
3031
#include <vector>
3132

@@ -361,18 +362,20 @@ class GNNBjetAllocator : public TensorAllocator
361362

362363
std::vector<std::vector<int64_t>> edgesList;
363364

365+
std::function<float(float)> tfFunc;
366+
364367
// Jet feature normalization
365368
template <typename T>
366369
T jetFeatureTransform(T feat, int idx) const
367370
{
368-
return std::tanh((feat - tfJetMean[idx]) / tfJetStdev[idx]);
371+
return tfFunc((feat - tfJetMean[idx]) / tfJetStdev[idx]);
369372
}
370373

371374
// Track feature normalization
372375
template <typename T>
373376
T trkFeatureTransform(T feat, int idx) const
374377
{
375-
return std::tanh((feat - tfTrkMean[idx]) / tfTrkStdev[idx]);
378+
return tfFunc((feat - tfTrkMean[idx]) / tfTrkStdev[idx]);
376379
}
377380

378381
// Edge input of GNN (fully-connected graph)
@@ -419,10 +422,17 @@ class GNNBjetAllocator : public TensorAllocator
419422
}
420423

421424
public:
422-
GNNBjetAllocator() : TensorAllocator(), nJetFeat(4), nTrkFeat(13), nFlav(3), nTrkOrigin(5), maxNNodes(40) {}
423-
GNNBjetAllocator(int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector<float>& tfJetMean, std::vector<float>& tfJetStdev, std::vector<float>& tfTrkMean, std::vector<float>& tfTrkStdev, int64_t maxNNodes = 40)
424-
: TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev)
425+
GNNBjetAllocator() : TensorAllocator(), nJetFeat(4), nTrkFeat(13), nFlav(3), nTrkOrigin(5), maxNNodes(40), tfFunc([](float x) { return x; }) {}
426+
GNNBjetAllocator(int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector<float>& tfJetMean, std::vector<float>& tfJetStdev, std::vector<float>& tfTrkMean, std::vector<float>& tfTrkStdev, int64_t maxNNodes = 40, std::string tfFuncType = "linear")
427+
: TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev), tfFunc([](float x) { return x; })
425428
{
429+
if (tfFuncType == "asinh") {
430+
tfFunc = [](float x) { return std::asinh(x); };
431+
} else if (tfFuncType == "tanh") {
432+
tfFunc = [](float x) { return std::tanh(x); };
433+
} else {
434+
tfFunc = [](float x) { return x; };
435+
}
426436
setEdgesList();
427437
}
428438
~GNNBjetAllocator() = default;
@@ -439,6 +449,8 @@ class GNNBjetAllocator : public TensorAllocator
439449
tfJetStdev = other.tfJetStdev;
440450
tfTrkMean = other.tfTrkMean;
441451
tfTrkStdev = other.tfTrkStdev;
452+
tfFunc = other.tfFunc;
453+
edgesList.clear();
442454
setEdgesList();
443455
return *this;
444456
}

PWGJE/TableProducer/jetTaggerHF.cxx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ struct JetTaggerHFTask {
139139
Configurable<std::vector<float>> transformFeatureTrkStdev{"transformFeatureTrkStdev",
140140
std::vector<float>{-999},
141141
"Stdev values for each GNN input feature (track)"};
142+
Configurable<std::string> tfFuncTypeGNN{"tfFuncTypeGNN", "linear", "Transformation function type for GNN"};
142143

143144
// axis spec
144145
ConfigurableAxis binTrackProbability{"binTrackProbability", {100, 0.f, 1.f}, ""};
@@ -525,7 +526,7 @@ struct JetTaggerHFTask {
525526
}
526527

527528
if (doprocessAlgorithmGNN) {
528-
tensorAlloc = o2::analysis::GNNBjetAllocator(nJetFeat.value, nTrkFeat.value, nClassesMl.value, nTrkOrigin.value, transformFeatureJetMean.value, transformFeatureJetStdev.value, transformFeatureTrkMean.value, transformFeatureTrkStdev.value, nJetConst);
529+
tensorAlloc = o2::analysis::GNNBjetAllocator(nJetFeat.value, nTrkFeat.value, nClassesMl.value, nTrkOrigin.value, transformFeatureJetMean.value, transformFeatureJetStdev.value, transformFeatureTrkMean.value, transformFeatureTrkStdev.value, nJetConst, tfFuncTypeGNN.value);
529530

530531
registry.add("h2_count_db", "#it{D}_{b} underflow/overflow;Jet flavour;#it{D}_{b} range", {HistType::kTH2F, {{4, 0., 4.}, {3, 0., 3.}}});
531532
auto h2CountDb = registry.get<TH2>(HIST("h2_count_db"));

PWGJE/Tasks/bjetTaggingGnn.cxx

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -809,80 +809,6 @@ struct BjetTaggingGnn {
809809
registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_b_inelgt0"), analysisJet.pt(), mcpjetpT, isTrueINELgt0 && (hasAll(evtselCode, EvtSelFlag::INELgt0rec)) ? weightEvt : 0.0);
810810
}
811811
}
812-
813-
// switch (evtselCode) {
814-
// case static_cast<int>(EvtSel::INELgt0rec:
815-
// registry.fill(HIST("h_jetpT_inelgt0rec"), analysisJet.pt(), weightEvt);
816-
// if (isMatched) {
817-
// registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_inelgt0"), analysisJet.pt(), mcpjetpT, weightEvt);
818-
// }
819-
// if (isBjet) {
820-
// registry.fill(HIST("h_jetpT_b_inelgt0"), analysisJet.pt(), weightEvt);
821-
// if (isMatched) {
822-
// registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_b_inelgt0"), analysisJet.pt(), mcpjetpT, weightEvt);
823-
// }
824-
// }
825-
// case static_cast<int>(EvtSel::Sel8Zvtx:
826-
// registry.fill(HIST("h_jetpT_sel8_zvtx"), analysisJet.pt(), weightEvt);
827-
// if (isMatched) {
828-
// registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_sel8"), analysisJet.pt(), mcpjetpT, weightEvt);
829-
// }
830-
// if (isBjet) {
831-
// registry.fill(HIST("h_jetpT_b_sel8_zvtx"), analysisJet.pt(), weightEvt);
832-
// if (isMatched) {
833-
// registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_b_sel8"), analysisJet.pt(), mcpjetpT, weightEvt);
834-
// }
835-
// }
836-
// case static_cast<int>(EvtSel::SelMCZvtx:
837-
// registry.fill(HIST("h_jetpT_selmc_zvtx"), analysisJet.pt(), weightEvt);
838-
// if (isMatched) {
839-
// registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_selmc"), analysisJet.pt(), mcpjetpT, weightEvt);
840-
// }
841-
// if (isBjet) {
842-
// registry.fill(HIST("h_jetpT_b_selmc_zvtx"), analysisJet.pt(), weightEvt);
843-
// if (isMatched) {
844-
// registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_b_selmc"), analysisJet.pt(), mcpjetpT, weightEvt);
845-
// }
846-
// }
847-
// case static_cast<int>(EvtSel::TVXZvtx:
848-
// registry.fill(HIST("h_jetpT_tvx_zvtx"), analysisJet.pt(), weightEvt);
849-
// if (isBjet) {
850-
// registry.fill(HIST("h_jetpT_b_tvx_zvtx"), analysisJet.pt(), weightEvt);
851-
// }
852-
// case static_cast<int>(EvtSel::CollZvtx:
853-
// registry.fill(HIST("h_jetpT_coll_zvtx"), analysisJet.pt(), weightEvt);
854-
// if (isBjet) {
855-
// registry.fill(HIST("h_jetpT_b_coll_zvtx"), analysisJet.pt(), weightEvt);
856-
// }
857-
// default:
858-
// switch (evtselCode) {
859-
// case static_cast<int>(EvtSel::Sel8:
860-
// case static_cast<int>(EvtSel::Sel8Zvtx:
861-
// registry.fill(HIST("h_jetpT_sel8"), analysisJet.pt(), weightEvt);
862-
// if (isBjet) {
863-
// registry.fill(HIST("h_jetpT_b_sel8"), analysisJet.pt(), weightEvt);
864-
// }
865-
// case static_cast<int>(EvtSel::SelMC:
866-
// case static_cast<int>(EvtSel::SelMCZvtx:
867-
// registry.fill(HIST("h_jetpT_selmc"), analysisJet.pt(), weightEvt);
868-
// if (isBjet) {
869-
// registry.fill(HIST("h_jetpT_b_selmc"), analysisJet.pt(), weightEvt);
870-
// }
871-
// case static_cast<int>(EvtSel::TVX:
872-
// case static_cast<int>(EvtSel::TVXZvtx:
873-
// registry.fill(HIST("h_jetpT_tvx"), analysisJet.pt(), weightEvt);
874-
// if (isBjet) {
875-
// registry.fill(HIST("h_jetpT_b_tvx"), analysisJet.pt(), weightEvt);
876-
// }
877-
// case static_cast<int>(EvtSel::Coll:
878-
// case static_cast<int>(EvtSel::CollZvtx:
879-
// default:
880-
// registry.fill(HIST("h_jetpT_coll"), analysisJet.pt(), weightEvt);
881-
// if (isBjet) {
882-
// registry.fill(HIST("h_jetpT_b_coll"), analysisJet.pt(), weightEvt);
883-
// }
884-
// }
885-
// }
886812
}
887813
}
888814
PROCESS_SWITCH(BjetTaggingGnn, processMCDJetsSel, "jet information in MC (event selection)", false);

0 commit comments

Comments
 (0)