1+ // to use computeRadiusFromThreeAnchorHits
2+ #include " LSTEvent.h"
3+ #include " Triplet.h"
4+
15#include " write_lst_ntuple.h"
26
37using namespace ALPAKA_ACCELERATOR_NAMESPACE ::lst;
@@ -72,6 +76,8 @@ void createOptionalOutputBranches() {
7276 ana.tx ->createBranch <std::vector<float >>(" pT5_eta" );
7377 ana.tx ->createBranch <std::vector<float >>(" pT5_phi" );
7478 ana.tx ->createBranch <std::vector<int >>(" pT5_isFake" );
79+ ana.tx ->createBranch <std::vector<float >>(" t5_sim_vxy" );
80+ ana.tx ->createBranch <std::vector<float >>(" t5_sim_vz" );
7581 ana.tx ->createBranch <std::vector<int >>(" pT5_isDuplicate" );
7682 ana.tx ->createBranch <std::vector<int >>(" pT5_score" );
7783 ana.tx ->createBranch <std::vector<int >>(" pT5_layer_binary" );
@@ -118,6 +124,7 @@ void createOptionalOutputBranches() {
118124 ana.tx ->createBranch <std::vector<int >>(" t5_isDuplicate" );
119125 ana.tx ->createBranch <std::vector<int >>(" t5_foundDuplicate" );
120126 ana.tx ->createBranch <std::vector<float >>(" t5_pt" );
127+ ana.tx ->createBranch <std::vector<float >>(" t5_pMatched" );
121128 ana.tx ->createBranch <std::vector<float >>(" t5_eta" );
122129 ana.tx ->createBranch <std::vector<float >>(" t5_phi" );
123130 ana.tx ->createBranch <std::vector<float >>(" t5_score_rphisum" );
@@ -126,13 +133,14 @@ void createOptionalOutputBranches() {
126133 ana.tx ->createBranch <std::vector<int >>(" t5_moduleType_binary" );
127134 ana.tx ->createBranch <std::vector<int >>(" t5_layer_binary" );
128135 ana.tx ->createBranch <std::vector<float >>(" t5_matched_pt" );
129- ana.tx ->createBranch <std::vector<int >>(" t5_partOfTC" );
130136 ana.tx ->createBranch <std::vector<float >>(" t5_innerRadius" );
131137 ana.tx ->createBranch <std::vector<float >>(" t5_outerRadius" );
132138 ana.tx ->createBranch <std::vector<float >>(" t5_bridgeRadius" );
133139 ana.tx ->createBranch <std::vector<float >>(" t5_chiSquared" );
134140 ana.tx ->createBranch <std::vector<float >>(" t5_rzChiSquared" );
135141 ana.tx ->createBranch <std::vector<float >>(" t5_nonAnchorChiSquared" );
142+ ana.tx ->createBranch <std::vector<float >>(" t5_dBeta1" );
143+ ana.tx ->createBranch <std::vector<float >>(" t5_dBeta2" );
136144
137145 // Occupancy branches
138146 ana.tx ->createBranch <std::vector<int >>(" module_layers" );
@@ -151,9 +159,39 @@ void createOptionalOutputBranches() {
151159 ana.tx ->createBranch <int >(" pT3_occupancies" );
152160 ana.tx ->createBranch <int >(" pT5_occupancies" );
153161
162+ // T5 DNN branches
163+ createT5DNNBranches ();
164+
154165#endif
155166}
156167
168+ // ________________________________________________________________________________________________________________________________
169+ void createT5DNNBranches () {
170+ // Common branches
171+ ana.tx ->createBranch <std::vector<int >>(" t5_t3_idx0" );
172+ ana.tx ->createBranch <std::vector<int >>(" t5_t3_idx1" );
173+ ana.tx ->createBranch <std::vector<int >>(" t5_tc_idx" );
174+ ana.tx ->createBranch <std::vector<int >>(" t5_partOfTC" );
175+ ana.tx ->createBranch <std::vector<float >>(" t5_t3_pt" );
176+ ana.tx ->createBranch <std::vector<float >>(" t5_t3_eta" );
177+ ana.tx ->createBranch <std::vector<float >>(" t5_t3_phi" );
178+
179+ // Hit-specific branches
180+ std::vector<std::string> hitIndices = {" 0" , " 1" , " 2" , " 3" , " 4" , " 5" };
181+ std::vector<std::string> hitProperties = {" r" , " x" , " y" , " z" , " eta" , " phi" , " detId" , " layer" , " moduleType" };
182+
183+ for (const auto & idx : hitIndices) {
184+ for (const auto & prop : hitProperties) {
185+ std::string branchName = " t5_t3_" + idx + " _" + prop;
186+ if (prop == " detId" || prop == " layer" || prop == " moduleType" ) {
187+ ana.tx ->createBranch <std::vector<int >>(branchName);
188+ } else {
189+ ana.tx ->createBranch <std::vector<float >>(branchName);
190+ }
191+ }
192+ }
193+ }
194+
157195// ________________________________________________________________________________________________________________________________
158196void createGnnNtupleBranches () {
159197 // Mini Doublets
@@ -302,6 +340,7 @@ void setOptionalOutputBranches(LSTEvent* event) {
302340 setQuintupletOutputBranches (event);
303341 setPixelTripletOutputBranches (event);
304342 setOccupancyBranches (event);
343+ setT5DNNBranches (event);
305344
306345#endif
307346}
@@ -474,17 +513,22 @@ void setQuintupletOutputBranches(LSTEvent* event) {
474513 moduleType_binary |= (modules.moduleType ()[module_idx[i]] << i);
475514 }
476515
477- std::vector<int > simidx = matchedSimTrkIdxs (hit_idx, hit_type);
516+ float percent_matched;
517+ std::vector<int > simidx = matchedSimTrkIdxs (hit_idx, hit_type, false , &percent_matched);
478518
479519 ana.tx ->pushbackToBranch <int >(" t5_isFake" , static_cast <int >(simidx.size () == 0 ));
480520 ana.tx ->pushbackToBranch <float >(" t5_pt" , pt);
521+ ana.tx ->pushbackToBranch <float >(" t5_pMatched" , percent_matched);
481522 ana.tx ->pushbackToBranch <float >(" t5_eta" , eta);
482523 ana.tx ->pushbackToBranch <float >(" t5_phi" , phi);
483524 ana.tx ->pushbackToBranch <float >(" t5_innerRadius" , __H2F (quintuplets.innerRadius ()[quintupletIndex]));
484525 ana.tx ->pushbackToBranch <float >(" t5_bridgeRadius" , __H2F (quintuplets.bridgeRadius ()[quintupletIndex]));
485526 ana.tx ->pushbackToBranch <float >(" t5_outerRadius" , __H2F (quintuplets.outerRadius ()[quintupletIndex]));
486527 ana.tx ->pushbackToBranch <float >(" t5_chiSquared" , quintuplets.chiSquared ()[quintupletIndex]);
487528 ana.tx ->pushbackToBranch <float >(" t5_rzChiSquared" , quintuplets.rzChiSquared ()[quintupletIndex]);
529+ ana.tx ->pushbackToBranch <float >(" t5_nonAnchorChiSquared" , quintuplets.nonAnchorChiSquared ()[quintupletIndex]);
530+ ana.tx ->pushbackToBranch <float >(" t5_dBeta1" , quintuplets.dBeta1 ()[quintupletIndex]);
531+ ana.tx ->pushbackToBranch <float >(" t5_dBeta2" , quintuplets.dBeta2 ()[quintupletIndex]);
488532 ana.tx ->pushbackToBranch <int >(" t5_layer_binary" , layer_binary);
489533 ana.tx ->pushbackToBranch <int >(" t5_moduleType_binary" , moduleType_binary);
490534
@@ -495,6 +539,21 @@ void setQuintupletOutputBranches(LSTEvent* event) {
495539 sim_t5_matched.at (simtrk) += 1 ;
496540 }
497541 }
542+
543+ // Avoid fakes when calculating the vertex distance, set default to 0.0.
544+ if (simidx.size () == 0 ) {
545+ ana.tx ->pushbackToBranch <float >(" t5_sim_vxy" , 0.0 );
546+ ana.tx ->pushbackToBranch <float >(" t5_sim_vz" , 0.0 );
547+ continue ;
548+ }
549+
550+ int vtxidx = trk.sim_parentVtxIdx ()[simidx[0 ]];
551+ float vtx_x = trk.simvtx_x ()[vtxidx];
552+ float vtx_y = trk.simvtx_y ()[vtxidx];
553+ float vtx_z = trk.simvtx_z ()[vtxidx];
554+
555+ ana.tx ->pushbackToBranch <float >(" t5_sim_vxy" , sqrt (vtx_x * vtx_x + vtx_y * vtx_y));
556+ ana.tx ->pushbackToBranch <float >(" t5_sim_vz" , vtx_z);
498557 }
499558 }
500559
@@ -579,6 +638,109 @@ void setPixelTripletOutputBranches(LSTEvent* event) {
579638 ana.tx ->setBranch <std::vector<int >>(" pT3_isDuplicate" , pT3_isDuplicate);
580639}
581640
641+ // ________________________________________________________________________________________________________________________________
642+ void fillT5DNNBranches (LSTEvent* event, unsigned int iT3) {
643+ auto hits = event->getHits <HitsSoA>();
644+ auto modules = event->getModules <ModulesSoA>();
645+
646+ std::vector<unsigned int > hitIdx = getHitsFromT3 (event, iT3);
647+ std::vector<lst_math::Hit> hitObjects (hitIdx.size ());
648+
649+ for (int i = 0 ; i < hitIdx.size (); ++i) {
650+ unsigned int hit = hitIdx[i];
651+ float x = hits.xs ()[hit];
652+ float y = hits.ys ()[hit];
653+ float z = hits.zs ()[hit];
654+ hitObjects[i] = lst_math::Hit (x, y, z);
655+
656+ std::string idx = std::to_string (i);
657+ ana.tx ->pushbackToBranch <float >(" t5_t3_" + idx + " _r" , sqrt (x * x + y * y));
658+ ana.tx ->pushbackToBranch <float >(" t5_t3_" + idx + " _x" , x);
659+ ana.tx ->pushbackToBranch <float >(" t5_t3_" + idx + " _y" , y);
660+ ana.tx ->pushbackToBranch <float >(" t5_t3_" + idx + " _z" , z);
661+ ana.tx ->pushbackToBranch <float >(" t5_t3_" + idx + " _eta" , hitObjects[i].eta ());
662+ ana.tx ->pushbackToBranch <float >(" t5_t3_" + idx + " _phi" , hitObjects[i].phi ());
663+
664+ int subdet = trk.ph2_subdet ()[hits.idxs ()[hit]];
665+ int is_endcap = subdet == 4 ;
666+ int layer = trk.ph2_layer ()[hits.idxs ()[hit]] + 6 * is_endcap;
667+ int detId = trk.ph2_detId ()[hits.idxs ()[hit]];
668+ unsigned int module = hits.moduleIndices ()[hit];
669+
670+ ana.tx ->pushbackToBranch <int >(" t5_t3_" + idx + " _detId" , detId);
671+ ana.tx ->pushbackToBranch <int >(" t5_t3_" + idx + " _layer" , layer);
672+ ana.tx ->pushbackToBranch <int >(" t5_t3_" + idx + " _moduleType" , modules.moduleType ()[module ]);
673+ }
674+
675+ float g, f;
676+ auto const & devHost = cms::alpakatools::host ();
677+ float radius = computeRadiusFromThreeAnchorHits (devHost,
678+ hitObjects[0 ].x (),
679+ hitObjects[0 ].y (),
680+ hitObjects[1 ].x (),
681+ hitObjects[1 ].y (),
682+ hitObjects[2 ].x (),
683+ hitObjects[2 ].y (),
684+ g,
685+ f);
686+ ana.tx ->pushbackToBranch <float >(" t5_t3_pt" , k2Rinv1GeVf * 2 * radius);
687+
688+ // Angles
689+ ana.tx ->pushbackToBranch <float >(" t5_t3_eta" , hitObjects[2 ].eta ());
690+ ana.tx ->pushbackToBranch <float >(" t5_t3_phi" , hitObjects[0 ].phi ());
691+ }
692+
693+ // ________________________________________________________________________________________________________________________________
694+ void setT5DNNBranches (LSTEvent* event) {
695+ auto triplets = event->getTriplets <TripletsOccupancySoA>();
696+ auto modules = event->getModules <ModulesSoA>();
697+ auto ranges = event->getRanges ();
698+ auto const quintuplets = event->getQuintuplets <QuintupletsOccupancySoA>();
699+ auto trackCandidates = event->getTrackCandidates ();
700+
701+ std::unordered_set<unsigned int > allT3s;
702+ std::unordered_map<unsigned int , unsigned int > t3_index_map;
703+
704+ for (unsigned int idx = 0 ; idx < modules.nLowerModules (); ++idx) {
705+ for (unsigned int jdx = 0 ; jdx < triplets.nTriplets ()[idx]; ++jdx) {
706+ unsigned int t3Idx = ranges.tripletModuleIndices ()[idx] + jdx;
707+ if (allT3s.insert (t3Idx).second ) {
708+ t3_index_map[t3Idx] = allT3s.size () - 1 ;
709+ fillT5DNNBranches (event, t3Idx);
710+ }
711+ }
712+ }
713+
714+ std::unordered_map<unsigned int , unsigned int > t5_tc_index_map;
715+ std::unordered_set<unsigned int > t5s_used_in_tc;
716+
717+ for (unsigned int idx = 0 ; idx < trackCandidates.nTrackCandidates (); idx++) {
718+ if (trackCandidates.trackCandidateType ()[idx] == LSTObjType::T5) {
719+ unsigned int objIdx = trackCandidates.directObjectIndices ()[idx];
720+ t5s_used_in_tc.insert (objIdx);
721+ t5_tc_index_map[objIdx] = idx;
722+ }
723+ }
724+
725+ for (unsigned int idx = 0 ; idx < modules.nLowerModules (); ++idx) {
726+ for (unsigned int jdx = 0 ; jdx < quintuplets.nQuintuplets ()[idx]; ++jdx) {
727+ unsigned int t5Idx = ranges.quintupletModuleIndices ()[idx] + jdx;
728+ std::vector<unsigned int > t3sIdx = getT3sFromT5 (event, t5Idx);
729+
730+ ana.tx ->pushbackToBranch <int >(" t5_t3_idx0" , t3_index_map[t3sIdx[0 ]]);
731+ ana.tx ->pushbackToBranch <int >(" t5_t3_idx1" , t3_index_map[t3sIdx[1 ]]);
732+
733+ if (t5s_used_in_tc.find (t5Idx) != t5s_used_in_tc.end ()) {
734+ ana.tx ->pushbackToBranch <int >(" t5_partOfTC" , 1 );
735+ ana.tx ->pushbackToBranch <int >(" t5_tc_idx" , t5_tc_index_map[t5Idx]);
736+ } else {
737+ ana.tx ->pushbackToBranch <int >(" t5_partOfTC" , 0 );
738+ ana.tx ->pushbackToBranch <int >(" t5_tc_idx" , -999 );
739+ }
740+ }
741+ }
742+ }
743+
582744// ________________________________________________________________________________________________________________________________
583745void setGnnNtupleBranches (LSTEvent* event) {
584746 // Get relevant information
@@ -798,16 +960,16 @@ std::tuple<int, float, float, float, int, std::vector<int>> parseTrackCandidate(
798960 float pt, eta, phi;
799961 std::vector<unsigned int > hit_idx, hit_type;
800962 switch (type) {
801- case lst:: LSTObjType::pT5:
963+ case LSTObjType::pT5:
802964 std::tie (pt, eta, phi, hit_idx, hit_type) = parsepT5 (event, idx);
803965 break ;
804- case lst:: LSTObjType::pT3:
966+ case LSTObjType::pT3:
805967 std::tie (pt, eta, phi, hit_idx, hit_type) = parsepT3 (event, idx);
806968 break ;
807- case lst:: LSTObjType::T5:
969+ case LSTObjType::T5:
808970 std::tie (pt, eta, phi, hit_idx, hit_type) = parseT5 (event, idx);
809971 break ;
810- case lst:: LSTObjType::pLS:
972+ case LSTObjType::pLS:
811973 std::tie (pt, eta, phi, hit_idx, hit_type) = parsepLS (event, idx);
812974 break ;
813975 }
0 commit comments