Skip to content

Commit 3295a76

Browse files
authored
Merge 331d446 into sapling-pr-archive-ktf
2 parents 3154413 + 331d446 commit 3295a76

File tree

71 files changed

+1827
-283
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+1827
-283
lines changed

Common/ML/CMakeLists.txt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,21 @@
99
# granted to it by virtue of its status as an Intergovernmental Organization
1010
# or submit itself to any jurisdiction.
1111

12+
# Pass ORT variables as a preprocessor definition
13+
if(DEFINED ENV{ORT_ROCM_BUILD})
14+
add_compile_definitions(ORT_ROCM_BUILD=$ENV{ORT_ROCM_BUILD})
15+
endif()
16+
if(DEFINED ENV{ORT_CUDA_BUILD})
17+
add_compile_definitions(ORT_CUDA_BUILD=$ENV{ORT_CUDA_BUILD})
18+
endif()
19+
if(DEFINED ENV{ORT_MIGRAPHX_BUILD})
20+
add_compile_definitions(ORT_MIGRAPHX_BUILD=$ENV{ORT_MIGRAPHX_BUILD})
21+
endif()
22+
if(DEFINED ENV{ORT_TENSORRT_BUILD})
23+
add_compile_definitions(ORT_TENSORRT_BUILD=$ENV{ORT_TENSORRT_BUILD})
24+
endif()
25+
1226
o2_add_library(ML
13-
SOURCES src/ort_interface.cxx
27+
SOURCES src/OrtInterface.cxx
1428
TARGETVARNAME targetName
1529
PRIVATE_LINK_LIBRARIES O2::Framework ONNXRuntime::ONNXRuntime)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
// granted to it by virtue of its status as an Intergovernmental Organization
1010
// or submit itself to any jurisdiction.
1111

12-
/// \file ort_interface.h
12+
/// \file OrtInterface.h
1313
/// \author Christian Sonnabend <[email protected]>
1414
/// \brief A header library for loading ONNX models and inferencing them on CPU and GPU
1515

16-
#ifndef O2_ML_ONNX_INTERFACE_H
17-
#define O2_ML_ONNX_INTERFACE_H
16+
#ifndef O2_ML_ORTINTERFACE_H
17+
#define O2_ML_ORTINTERFACE_H
1818

1919
// C++ and system includes
2020
#include <vector>
@@ -89,4 +89,4 @@ class OrtModel
8989

9090
} // namespace o2
9191

92-
#endif // O2_ML_ORT_INTERFACE_H
92+
#endif // O2_ML_ORTINTERFACE_H
Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
// granted to it by virtue of its status as an Intergovernmental Organization
1010
// or submit itself to any jurisdiction.
1111

12-
/// \file ort_interface.cxx
12+
/// \file OrtInterface.cxx
1313
/// \author Christian Sonnabend <[email protected]>
1414
/// \brief A header library for loading ONNX models and inferencing them on CPU and GPU
1515

16-
#include "ML/ort_interface.h"
16+
#include "ML/OrtInterface.h"
1717
#include "ML/3rdparty/GPUORTFloat16.h"
1818

1919
// ONNX includes
@@ -50,29 +50,35 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
5050
deviceId = (optionsMap.contains("device-id") ? std::stoi(optionsMap["device-id"]) : 0);
5151
allocateDeviceMemory = (optionsMap.contains("allocate-device-memory") ? std::stoi(optionsMap["allocate-device-memory"]) : 0);
5252
intraOpNumThreads = (optionsMap.contains("intra-op-num-threads") ? std::stoi(optionsMap["intra-op-num-threads"]) : 0);
53-
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 0);
53+
loggingLevel = (optionsMap.contains("logging-level") ? std::stoi(optionsMap["logging-level"]) : 2);
5454
enableProfiling = (optionsMap.contains("enable-profiling") ? std::stoi(optionsMap["enable-profiling"]) : 0);
5555
enableOptimizations = (optionsMap.contains("enable-optimizations") ? std::stoi(optionsMap["enable-optimizations"]) : 0);
5656

5757
std::string dev_mem_str = "Hip";
58-
#ifdef ORT_ROCM_BUILD
58+
#if defined(ORT_ROCM_BUILD)
59+
#if ORT_ROCM_BUILD == 1
5960
if (device == "ROCM") {
6061
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(pImplOrt->sessionOptions, deviceId));
6162
LOG(info) << "(ORT) ROCM execution provider set";
6263
}
6364
#endif
64-
#ifdef ORT_MIGRAPHX_BUILD
65+
#endif
66+
#if defined(ORT_MIGRAPHX_BUILD)
67+
#if ORT_MIGRAPHX_BUILD == 1
6568
if (device == "MIGRAPHX") {
6669
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(pImplOrt->sessionOptions, deviceId));
6770
LOG(info) << "(ORT) MIGraphX execution provider set";
6871
}
6972
#endif
70-
#ifdef ORT_CUDA_BUILD
73+
#endif
74+
#if defined(ORT_CUDA_BUILD)
75+
#if ORT_CUDA_BUILD == 1
7176
if (device == "CUDA") {
7277
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(pImplOrt->sessionOptions, deviceId));
7378
LOG(info) << "(ORT) CUDA execution provider set";
7479
dev_mem_str = "Cuda";
7580
}
81+
#endif
7682
#endif
7783

7884
if (allocateDeviceMemory) {
@@ -106,7 +112,27 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
106112
(pImplOrt->sessionOptions).SetGraphOptimizationLevel(GraphOptimizationLevel(enableOptimizations));
107113
(pImplOrt->sessionOptions).SetLogSeverityLevel(OrtLoggingLevel(loggingLevel));
108114

109-
pImplOrt->env = std::make_shared<Ort::Env>(OrtLoggingLevel(loggingLevel), (optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()));
115+
pImplOrt->env = std::make_shared<Ort::Env>(
116+
OrtLoggingLevel(loggingLevel),
117+
(optionsMap["onnx-environment-name"].empty() ? "onnx_model_inference" : optionsMap["onnx-environment-name"].c_str()),
118+
// Integrate ORT logging into Fairlogger
119+
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
120+
if (severity == ORT_LOGGING_LEVEL_VERBOSE) {
121+
LOG(debug) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
122+
} else if (severity == ORT_LOGGING_LEVEL_INFO) {
123+
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
124+
} else if (severity == ORT_LOGGING_LEVEL_WARNING) {
125+
LOG(warning) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
126+
} else if (severity == ORT_LOGGING_LEVEL_ERROR) {
127+
LOG(error) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
128+
} else if (severity == ORT_LOGGING_LEVEL_FATAL) {
129+
LOG(fatal) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
130+
} else {
131+
LOG(info) << "(ORT) [" << logid << "|" << category << "|" << code_location << "]: " << message;
132+
}
133+
},
134+
(void*)3);
135+
(pImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
110136
pImplOrt->session = std::make_shared<Ort::Session>(*(pImplOrt->env), modelPath.c_str(), pImplOrt->sessionOptions);
111137

112138
for (size_t i = 0; i < (pImplOrt->session)->GetInputCount(); ++i) {
@@ -130,16 +156,14 @@ void OrtModel::reset(std::unordered_map<std::string, std::string> optionsMap)
130156
[&](const std::string& str) { return str.c_str(); });
131157

132158
// Print names
133-
if (loggingLevel > 1) {
134-
LOG(info) << "Input Nodes:";
135-
for (size_t i = 0; i < mInputNames.size(); i++) {
136-
LOG(info) << "\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
137-
}
159+
LOG(info) << "\tInput Nodes:";
160+
for (size_t i = 0; i < mInputNames.size(); i++) {
161+
LOG(info) << "\t\t" << mInputNames[i] << " : " << printShape(mInputShapes[i]);
162+
}
138163

139-
LOG(info) << "Output Nodes:";
140-
for (size_t i = 0; i < mOutputNames.size(); i++) {
141-
LOG(info) << "\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
142-
}
164+
LOG(info) << "\tOutput Nodes:";
165+
for (size_t i = 0; i < mOutputNames.size(); i++) {
166+
LOG(info) << "\t\t" << mOutputNames[i] << " : " << printShape(mOutputShapes[i]);
143167
}
144168
}
145169

DataFormats/Detectors/CTP/src/Scalers.cxx

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -674,35 +674,35 @@ std::pair<double, double> CTPRunScalers::getRate(uint32_t orbit, int classindex,
674674

675675
// then we can use binary search to find the right entries
676676
auto iter = std::lower_bound(mScalerRecordO2.begin(), mScalerRecordO2.end(), orbit, [&](CTPScalerRecordO2 const& a, uint32_t value) { return a.intRecord.orbit <= value; });
677-
auto nextindex = iter - mScalerRecordO2.begin(); // this points to the first index that has orbit greater or equal to given orbit
677+
auto nextindex = std::distance(mScalerRecordO2.begin(), iter); // this points to the first index that has orbit greater or equal to given orbit
678678

679679
auto calcRate = [&](auto index1, auto index2) -> double {
680-
auto next = &mScalerRecordO2[index2];
681-
auto prev = &mScalerRecordO2[index1];
682-
auto timedelta = (next->intRecord.orbit - prev->intRecord.orbit) * 88.e-6; // converts orbits into time
680+
const auto& snext = mScalerRecordO2[index2];
681+
const auto& sprev = mScalerRecordO2[index1];
682+
auto timedelta = (snext.intRecord.orbit - sprev.intRecord.orbit) * 88.e-6; // converts orbits into time
683683
if (type < 7) {
684-
auto s0 = &(prev->scalers[classindex]); // type CTPScalerO2*
685-
auto s1 = &(next->scalers[classindex]);
684+
const auto& s0 = sprev.scalers[classindex]; // type CTPScalerO2*
685+
const auto& s1 = snext.scalers[classindex];
686686
switch (type) {
687687
case 1:
688-
return (s1->lmBefore - s0->lmBefore) / timedelta;
688+
return (s1.lmBefore - s0.lmBefore) / timedelta;
689689
case 2:
690-
return (s1->lmAfter - s0->lmAfter) / timedelta;
690+
return (s1.lmAfter - s0.lmAfter) / timedelta;
691691
case 3:
692-
return (s1->l0Before - s0->l0Before) / timedelta;
692+
return (s1.l0Before - s0.l0Before) / timedelta;
693693
case 4:
694-
return (s1->l0After - s0->l0After) / timedelta;
694+
return (s1.l0After - s0.l0After) / timedelta;
695695
case 5:
696-
return (s1->l1Before - s0->l1Before) / timedelta;
696+
return (s1.l1Before - s0.l1Before) / timedelta;
697697
case 6:
698-
return (s1->l1After - s0->l1After) / timedelta;
698+
return (s1.l1After - s0.l1After) / timedelta;
699699
default:
700700
LOG(error) << "Wrong type:" << type;
701701
return -1; // wrong type
702702
}
703703
} else if (type == 7) {
704-
auto s0 = &(prev->scalersInps[classindex]); // type CTPScalerO2*
705-
auto s1 = &(next->scalersInps[classindex]);
704+
auto s0 = sprev.scalersInps[classindex]; // type CTPScalerO2*
705+
auto s1 = snext.scalersInps[classindex];
706706
return (s1 - s0) / timedelta;
707707
} else {
708708
LOG(error) << "Wrong type:" << type;
@@ -738,37 +738,37 @@ std::pair<double, double> CTPRunScalers::getRateGivenT(double timestamp, int cla
738738
// this points to the first index that has orbit greater to given orbit;
739739
// If this is 0, it means that the above condition was false from the beginning, basically saying that the timestamp is below any of the ScalerRecords' orbits.
740740
// If this is mScalerRecordO2.size(), it means mScalerRecordO2.end() was returned, condition was met throughout all ScalerRecords, basically saying the timestamp is above any of the ScalarRecordss orbits.
741-
auto nextindex = iter - mScalerRecordO2.begin();
741+
auto nextindex = std::distance(mScalerRecordO2.begin(), iter);
742742

743743
auto calcRate = [&](auto index1, auto index2) -> double {
744-
auto next = &mScalerRecordO2[index2];
745-
auto prev = &mScalerRecordO2[index1];
746-
auto timedelta = (next->intRecord.orbit - prev->intRecord.orbit) * 88.e-6; // converts orbits into time
744+
const auto& snext = mScalerRecordO2[index2];
745+
const auto& sprev = mScalerRecordO2[index1];
746+
auto timedelta = (snext.intRecord.orbit - sprev.intRecord.orbit) * 88.e-6; // converts orbits into time
747747
// std::cout << "timedelta:" << timedelta << std::endl;
748748
if (type < 7) {
749-
auto s0 = &(prev->scalers[classindex]); // type CTPScalerO2*
750-
auto s1 = &(next->scalers[classindex]);
749+
const auto& s0 = sprev.scalers[classindex]; // type CTPScalerO2*
750+
const auto& s1 = snext.scalers[classindex];
751751
switch (type) {
752752
case 1:
753-
return (s1->lmBefore - s0->lmBefore) / timedelta;
753+
return (s1.lmBefore - s0.lmBefore) / timedelta;
754754
case 2:
755-
return (s1->lmAfter - s0->lmAfter) / timedelta;
755+
return (s1.lmAfter - s0.lmAfter) / timedelta;
756756
case 3:
757-
return (s1->l0Before - s0->l0Before) / timedelta;
757+
return (s1.l0Before - s0.l0Before) / timedelta;
758758
case 4:
759-
return (s1->l0After - s0->l0After) / timedelta;
759+
return (s1.l0After - s0.l0After) / timedelta;
760760
case 5:
761-
return (s1->l1Before - s0->l1Before) / timedelta;
761+
return (s1.l1Before - s0.l1Before) / timedelta;
762762
case 6:
763-
return (s1->l1After - s0->l1After) / timedelta;
763+
return (s1.l1After - s0.l1After) / timedelta;
764764
default:
765765
LOG(error) << "Wrong type:" << type;
766766
return -1; // wrong type
767767
}
768768
} else if (type == 7) {
769769
// LOG(info) << "doing input:";
770-
auto s0 = prev->scalersInps[classindex]; // type CTPScalerO2*
771-
auto s1 = next->scalersInps[classindex];
770+
auto s0 = sprev.scalersInps[classindex]; // type CTPScalerO2*
771+
auto s1 = snext.scalersInps[classindex];
772772
return (s1 - s0) / timedelta;
773773
} else {
774774
LOG(error) << "Wrong type:" << type;

DataFormats/Detectors/ZDC/include/DataFormatsZDC/BCData.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ struct BCData {
5555
o2::dataformats::RangeRefComp<6> ref;
5656
o2::InteractionRecord ir;
5757
std::array<uint16_t, NModules> moduleTriggers{};
58+
// N.B. channels and triggers have geographical addressing (0x1 << (NChPerModule * im + ic)
5859
uint32_t channels = 0; // pattern of channels it refers to
59-
uint32_t triggers = 0; // pattern of triggered channels (not necessarily stored) in this BC
60+
uint32_t triggers = 0; // pattern of triggered channels (not necessarily stored) in this BC (i.e. with Hit bit on)
6061
uint8_t ext_triggers = 0; // pattern of ALICE triggers
6162

6263
BCData() = default;

DataFormats/Reconstruction/include/ReconstructionDataFormats/MatchInfoTOF.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class MatchInfoTOF
2828
using GTrackID = o2::dataformats::GlobalTrackID;
2929

3030
public:
31-
MatchInfoTOF(int idLocal, int idxTOFCl, double time, float chi2, o2::track::TrackLTIntegral trkIntLT, GTrackID idxTrack, float dt = 0, float z = 0, float dx = 0, float dz = 0) : mIdLocal(idLocal), mIdxTOFCl(idxTOFCl), mSignal(time), mChi2(chi2), mIntLT(trkIntLT), mIdxTrack(idxTrack), mDeltaT(dt), mZatTOF(z), mDXatTOF(dx), mDZatTOF(dz){};
31+
MatchInfoTOF(int idLocal, int idxTOFCl, double time, float chi2, o2::track::TrackLTIntegral trkIntLT, GTrackID idxTrack, float dt = 0, float z = 0, float dx = 0, float dz = 0, float dy = 0) : mIdLocal(idLocal), mIdxTOFCl(idxTOFCl), mSignal(time), mChi2(chi2), mIntLT(trkIntLT), mIdxTrack(idxTrack), mDeltaT(dt), mZatTOF(z), mDXatTOF(dx), mDZatTOF(dz), mDYatTOF(dy){};
3232
MatchInfoTOF() = default;
3333
void setIdxTOFCl(int index) { mIdxTOFCl = index; }
3434
void setIdxTrack(GTrackID index) { mIdxTrack = index; }
@@ -59,6 +59,8 @@ class MatchInfoTOF
5959
float getDZatTOF() const { return mDZatTOF; }
6060
void setDXatTOF(float val) { mDXatTOF = val; }
6161
float getDXatTOF() const { return mDXatTOF; }
62+
void setDYatTOF(float val) { mDYatTOF = val; }
63+
float getDYatTOF() const { return mDYatTOF; }
6264
void setSignal(double time) { mSignal = time; }
6365
double getSignal() const { return mSignal; }
6466

@@ -78,6 +80,7 @@ class MatchInfoTOF
7880
float mZatTOF = 0.0; ///< Z position at TOF
7981
float mDXatTOF = 0.0; ///< DX position at TOF
8082
float mDZatTOF = 0.0; ///< DZ position at TOF
83+
float mDYatTOF = 0.0; ///< DY position at TOF
8184
float mDeltaT = 0.0; ///< tTOF - TPC (microsec)
8285
double mSignal = 0.0; ///< TOF time in ps
8386
float mVz = 0.0; ///< Vz from TOF match

DataFormats/Reconstruction/include/ReconstructionDataFormats/MatchInfoTOFReco.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class MatchInfoTOFReco : public MatchInfoTOF
3535
ITSTPCTRD,
3636
SIZEALL };
3737

38-
MatchInfoTOFReco(int idLocal, int idxTOFCl, double time, float chi2, o2::track::TrackLTIntegral trkIntLT, GTrackID idxTrack, TrackType trkType, float dt = 0, float z = 0, float dx = 0, float dz = 0) : MatchInfoTOF(idLocal, idxTOFCl, time, chi2, trkIntLT, idxTrack, dt, z, dx, dz), mTrackType(trkType){};
38+
MatchInfoTOFReco(int idLocal, int idxTOFCl, double time, float chi2, o2::track::TrackLTIntegral trkIntLT, GTrackID idxTrack, TrackType trkType, float dt = 0, float z = 0, float dx = 0, float dz = 0, float dy = 0) : MatchInfoTOF(idLocal, idxTOFCl, time, chi2, trkIntLT, idxTrack, dt, z, dx, dz, dy), mTrackType(trkType){};
3939

4040
MatchInfoTOFReco() = default;
4141

DataFormats/Reconstruction/include/ReconstructionDataFormats/TrackParametrizationWithError.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ class TrackParametrizationWithError : public TrackParametrization<value_T>
100100

101101
template <typename T>
102102
GPUd() value_t getPredictedChi2(const BaseCluster<T>& p) const;
103+
template <typename T>
104+
GPUd() value_t getPredictedChi2Quiet(const BaseCluster<T>& p) const;
103105

104106
GPUd() void buildCombinedCovMatrix(const TrackParametrizationWithError& rhs, MatrixDSym5& cov) const;
105107
GPUd() value_t getPredictedChi2(const TrackParametrizationWithError& rhs, MatrixDSym5& covToSet) const;
@@ -315,6 +317,16 @@ GPUdi() auto TrackParametrizationWithError<value_T>::getPredictedChi2(const Base
315317
return getPredictedChi2(pyz, cov);
316318
}
317319

320+
//__________________________________________________________________________
321+
template <typename value_T>
322+
template <typename T>
323+
GPUdi() auto TrackParametrizationWithError<value_T>::getPredictedChi2Quiet(const BaseCluster<T>& p) const -> value_t
324+
{
325+
const dim2_t pyz = {value_T(p.getY()), value_T(p.getZ())};
326+
const dim3_t cov = {value_T(p.getSigmaY2()), value_T(p.getSigmaYZ()), value_T(p.getSigmaZ2())};
327+
return getPredictedChi2Quiet(pyz, cov);
328+
}
329+
318330
//______________________________________________
319331
template <typename value_T>
320332
GPUdi() auto TrackParametrizationWithError<value_T>::getPredictedChi2(const dim2_t& p, const dim3_t& cov) const -> value_t

Detectors/GlobalTracking/include/GlobalTracking/MatchTOF.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ class MatchTOF
220220
void BestMatchesHP(std::vector<o2::dataformats::MatchInfoTOFReco>& matchedTracksPairs, std::vector<o2::dataformats::MatchInfoTOF>* matchedTracks, std::vector<int>* matchedTracksIndex, int* matchedClustersIndex, const gsl::span<const o2::ft0::RecPoints>& FITRecPoints, const std::vector<Cluster>& TOFClusWork, std::vector<o2::dataformats::CalibInfoTOF>& CalibInfoTOF, unsigned long Timestamp, bool MCTruthON, const o2::dataformats::MCTruthContainer<o2::MCCompLabel>* TOFClusLabels, const std::vector<o2::MCCompLabel>* TracksLblWork, std::vector<o2::MCCompLabel>* OutTOFLabels);
221221
bool propagateToRefX(o2::track::TrackParCov& trc, float xRef /*in cm*/, float stepInCm /*in cm*/, o2::track::TrackLTIntegral& intLT);
222222
bool propagateToRefXWithoutCov(const o2::track::TrackParCov& trc, float xRef /*in cm*/, float stepInCm /*in cm*/, float bz);
223+
bool propagateToRefXWithoutCov(const o2::track::TrackParCov& trc, float xRef /*in cm*/, float stepInCm /*in cm*/, float bz, float pos[3]);
224+
void updateTL(o2::track::TrackLTIntegral& intLT, float deltal);
223225

224226
void updateTimeDependentParams();
225227

0 commit comments

Comments
 (0)