4040#include " Common/DataModel/EventSelection.h"
4141#include " Common/DataModel/Multiplicity.h"
4242#include " Common/DataModel/PIDResponseTPC.h"
43+ #include " Common/CCDB/ctpRateFetcher.h"
4344#include " Tools/ML/model.h"
4445
4546#include " CCDB/BasicCCDBManager.h"
5051#include " Framework/runDataProcessing.h"
5152#include " ReconstructionDataFormats/Track.h"
5253
54+
5355namespace o2 ::aod
5456{
5557namespace pid
@@ -158,6 +160,43 @@ int getPIDIndex(const int pdgCode) // Get O2 PID index corresponding to MC PDG c
158160 }
159161}
160162
163+ typedef struct Str_dEdx_correction {
164+ TMatrixD fMatrix ;
165+ bool warning = true ;
166+
167+ // void init(std::vector<double>& params)
168+ void init ()
169+ {
170+ double elements[32 ] = {0.99091 , -0.015053 , 0.0018912 , -0.012305 ,
171+ 0.081387 , 0.003205 , -0.0087404 , -0.0028608 ,
172+ 0.013066 , 0.017012 , -0.0018469 , -0.0052177 ,
173+ -0.0035655 , 0.0017846 , 0.0019127 , -0.00012964 ,
174+ 0.0049428 , 0.0055592 , -0.0010618 , -0.0016134 ,
175+ -0.0059098 , 0.0013335 , 0.00052133 , 3.1119e-05 ,
176+ -0.004882 , 0.00077317 , -0.0013827 , 0.003249 ,
177+ -0.00063689 , 0.0016218 , -0.00045215 , -1.5815e-05 };
178+ fMatrix .ResizeTo (4 , 8 );
179+ fMatrix .SetMatrixArray (elements);
180+ }
181+
182+ float fReal_fTPCSignalN (std::vector<float > vec1, std::vector<float > vec2)
183+ {
184+ float result = 0 .f ;
185+ // push 1.
186+ vec1.insert (vec1.begin (), 1.0 );
187+ vec2.insert (vec2.begin (), 1.0 );
188+ for (int i = 0 ; i < fMatrix .GetNrows (); i++) {
189+ for (int j = 0 ; j < fMatrix .GetNcols (); j++) {
190+ double param = fMatrix (i, j);
191+ double value1 = i > static_cast <int >(vec1.size ()) ? 0 : vec1[i];
192+ double value2 = j > static_cast <int >(vec2.size ()) ? 0 : vec2[j];
193+ result += param * value1 * value2;
194+ }
195+ }
196+ return result;
197+ }
198+ } Str_dEdx_correction;
199+
161200class pidTPCModule
162201{
163202 public:
@@ -181,6 +220,10 @@ class pidTPCModule
181220 // Parametrization configuration
182221 bool useCCDBParam = false ;
183222
223+ // for dEdx correction
224+ ctpRateFetcher mRateFetcher ;
225+ Str_dEdx_correction str_dedx_correction;
226+
184227 // __________________________________________________
185228 template <typename TCCDB, typename TCCDBApi, typename TContext, typename TpidTPCOpts, typename TMetadataInfo>
186229 void init (TCCDB& ccdb, TCCDBApi& ccdbApi, TContext& context, TpidTPCOpts const & external_pidtpcopts, TMetadataInfo const & metadataInfo)
@@ -547,7 +590,75 @@ class pidTPCModule
547590 uint64_t count_tracks = 0 ;
548591
549592 for (auto const & trk : tracks) {
550- // Loop on Tracks
593+ // get the TPC signal to be used in the PID
594+ float tpcSignalToEvaluatePID = trk.tpcSignal ();
595+
596+ // if corrected dE/dx is requested, correct it here on the spot and use that
597+ if (pidTPCopts.useCorrecteddEdx ){
598+ double hadronicRate;
599+ int multTPC;
600+ int occupancy;
601+ if (trk.has_collision ()) {
602+ auto collision = cols.iteratorAt (trk.collisionId ());
603+ auto bc = collision.template bc_as <aod::BCsWithTimestamps>();
604+ const int runnumber = bc.runNumber ();
605+ hadronicRate = mRateFetcher .fetch (ccdb.service , bc.timestamp (), runnumber, " ZNC hadronic" ) * 1 .e -3 ; // kHz
606+ multTPC = pidmults[trk.collisionId ()];
607+ occupancy = collision.trackOccupancyInTimeRange ();
608+ } else {
609+ auto bc = bcs.begin ();
610+ const int runnumber = bc.runNumber ();
611+ hadronicRate = mRateFetcher .fetch (ccdb.service , bc.timestamp (), runnumber, " ZNC hadronic" ) * 1 .e -3 ; // kHz
612+ multTPC = 0 ;
613+ occupancy = 0 ;
614+ }
615+
616+ float fTPCSignal = trk.tpcSignal ();
617+ float fNormMultTPC = multTPC / 11000 .;
618+
619+ float fTrackOccN = occupancy / 1000 .;
620+ float fOccTPCN = fNormMultTPC * 10 ; // (fNormMultTPC*10).clip(0,12)
621+ if (fOccTPCN > 12 )
622+ fOccTPCN = 12 ;
623+ else if (fOccTPCN < 0 )
624+ fOccTPCN = 0 ;
625+
626+ float fTrackOccMeanN = hadronicRate / 5 ;
627+ float side = trk.tgl () > 0 ? 1 : 0 ;
628+ float a1pt = std::abs (trk.signed1Pt ());
629+ float a1pt2 = a1pt * a1pt;
630+ float atgl = std::abs (trk.tgl ());
631+ float mbb0R = 50 / fTPCSignal ;
632+ if (mbb0R > 1.05 )
633+ mbb0R = 1.05 ;
634+ else if (mbb0R < 0.05 )
635+ mbb0R = 0.05 ;
636+ // float mbb0R = max(0.05, min(50 / fTPCSignal, 1.05));
637+ float a1ptmbb0R = a1pt * mbb0R;
638+ float atglmbb0R = atgl * mbb0R;
639+
640+ std::vector<float > vec_occu = {fTrackOccN , fOccTPCN , fTrackOccMeanN };
641+ std::vector<float > vec_track = {mbb0R, a1pt, atgl, atglmbb0R, a1ptmbb0R, side, a1pt2};
642+
643+ float fTPCSignalN_CR0 = str_dedx_correction.fReal_fTPCSignalN (vec_occu, vec_track);
644+
645+ float mbb0R1 = 50 / (fTPCSignal / fTPCSignalN_CR0 );
646+ if (mbb0R1 > 1.05 )
647+ mbb0R1 = 1.05 ;
648+ else if (mbb0R1 < 0.05 )
649+ mbb0R1 = 0.05 ;
650+
651+ std::vector<float > vec_track1 = {mbb0R1, a1pt, atgl, atgl * mbb0R1, a1pt * mbb0R1, side, a1pt2};
652+ float fTPCSignalN_CR1 = str_dedx_correction.fReal_fTPCSignalN (vec_occu, vec_track1);
653+
654+ // change the signal used for PID
655+ tpcSignalToEvaluatePID = fTPCSignal / fTPCSignalN_CR1 ;
656+
657+ if (pidTPCopts.savedEdxsCorrected ){
658+ // populated cursor if requested or autodetected
659+ products.dEdxCorrected (tpcSignalToEvaluatePID);
660+ }
661+ }
551662
552663 const auto & bc = trk.has_collision () ? cols.iteratorAt (trk.collisionId ()).template bc_as <aod::BCsWithTimestamps>() : bcs.begin ();
553664 if (useCCDBParam && pidTPCopts.ccdbTimestamp .value == 0 && !ccdb->isCachedObjectValid (pidTPCopts.ccdbPath .value , bc.timestamp ())) { // Updating parametrisation only if the initial timestamp is 0
@@ -570,8 +681,8 @@ class pidTPCModule
570681 response->PrintAll ();
571682 }
572683
573- auto makePidTablesDefault = [&trk, &cols, &pidmults, &network_prediction, &count_tracks, &tracksForNet_size, this ](const int flagFull, auto & tableFull, const int flagTiny, auto & tableTiny, const o2::track::PID::ID pid) {
574- makePidTables (flagFull, tableFull, flagTiny, tableTiny, pid, trk. tpcSignal () , trk, cols, pidmults[trk.collisionId ()], network_prediction, count_tracks, tracksForNet_size);
684+ auto makePidTablesDefault = [&trk, &tpcSignalToEvaluatePID, & cols, &pidmults, &network_prediction, &count_tracks, &tracksForNet_size, this ](const int flagFull, auto & tableFull, const int flagTiny, auto & tableTiny, const o2::track::PID::ID pid) {
685+ makePidTables (flagFull, tableFull, flagTiny, tableTiny, pid, tpcSignalToEvaluatePID , trk, cols, pidmults[trk.collisionId ()], network_prediction, count_tracks, tracksForNet_size);
575686 };
576687
577688 makePidTablesDefault (pidTPCopts.pidFullEl , products.tablePIDFullEl , pidTPCopts.pidTinyEl , products.tablePIDTinyEl , o2::track::PID::Electron);
0 commit comments