Skip to content

Commit 1d620f2

Browse files
[NN Clusterizer] CCDB fetching within reco workflow (AliceO2Group#14841)
* Improve GPU filling kernel speed * Adjusting parameter bounds and additional GPU kernel optimizations * Adding back if statement for early exit * const'ing + fixing CPU kernel * Remiving print statements * Fixing CI build issue * Working version of NN CCDB fetching and loading to file * Cleanup * Please consider the following formatting changes * Using char* buffer for model loading * Please consider the following formatting changes * Bug-fix * Working version of CCDB fetching and loading into ROOT class of std::vector<char> * Please consider the following formatting changes * Disable dumpToFile by default * Moving macro, adding o2-test --------- Co-authored-by: ALICE Action Bot <[email protected]>
1 parent f9f3798 commit 1d620f2

File tree

19 files changed

+459
-146
lines changed

19 files changed

+459
-146
lines changed

Common/ML/include/ML/OrtInterface.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class OrtModel
5151
void initOptions(std::unordered_map<std::string, std::string> optionsMap);
5252
void initEnvironment();
5353
void initSession();
54+
void initSessionFromBuffer(const char* buffer, size_t bufferSize);
5455
void memoryOnDevice(int32_t = 0);
5556
bool isInitialized() { return mInitialized; }
5657
void resetSession();

Common/ML/src/OrtInterface.cxx

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,24 @@ void OrtModel::initEnvironment()
138138
(mPImplOrt->env)->DisableTelemetryEvents(); // Disable telemetry events
139139
}
140140

141+
void OrtModel::initSessionFromBuffer(const char* buffer, size_t bufferSize)
142+
{
143+
mPImplOrt->sessionOptions.AddConfigEntry("session.load_model_format", "ONNX");
144+
mPImplOrt->sessionOptions.AddConfigEntry("session.use_ort_model_bytes_directly", "1");
145+
146+
mPImplOrt->session = std::make_unique<Ort::Session>(*mPImplOrt->env,
147+
buffer,
148+
bufferSize,
149+
mPImplOrt->sessionOptions);
150+
mPImplOrt->ioBinding = std::make_unique<Ort::IoBinding>(*mPImplOrt->session);
151+
152+
setIO();
153+
154+
if (mLoggingLevel < 2) {
155+
LOG(info) << "(ORT) Model loaded successfully from buffer! (inputs: " << printShape(mInputShapes, mInputNames) << ", outputs: " << printShape(mOutputShapes, mInputNames) << ")";
156+
}
157+
}
158+
141159
void OrtModel::initSession()
142160
{
143161
if (mAllocateDeviceMemory) {

Detectors/TPC/base/test/testTPCCDBInterface.cxx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
// o2 includes
2424
#include "TPCBase/CDBInterface.h"
25-
#include "TPCBase/CDBInterface.h"
2625
#include "TPCBase/CalArray.h"
2726
#include "TPCBase/CalDet.h"
2827
#include "TPCBase/Mapper.h"

Detectors/TPC/calibration/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ o2_add_library(TPCCalibration
2525
src/CalibPadGainTracksBase.cxx
2626
src/CalibLaserTracks.cxx
2727
src/LaserTracksCalibrator.cxx
28-
src/NeuralNetworkClusterizer.cxx
2928
src/SACDecoder.cxx
3029
src/IDCAverageGroup.cxx
3130
src/IDCAverageGroupBase.cxx
@@ -84,7 +83,6 @@ o2_target_root_dictionary(TPCCalibration
8483
include/TPCCalibration/FastHisto.h
8584
include/TPCCalibration/CalibLaserTracks.h
8685
include/TPCCalibration/LaserTracksCalibrator.h
87-
include/TPCCalibration/NeuralNetworkClusterizer.h
8886
include/TPCCalibration/SACDecoder.h
8987
include/TPCCalibration/IDCAverageGroup.h
9088
include/TPCCalibration/IDCAverageGroupBase.h

Detectors/TPC/calibration/include/TPCCalibration/NeuralNetworkClusterizer.h

Lines changed: 0 additions & 38 deletions
This file was deleted.

Detectors/TPC/calibration/src/NeuralNetworkClusterizer.cxx

Lines changed: 0 additions & 48 deletions
This file was deleted.

GPU/GPUTracking/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ set(SRCS_DATATYPES
209209
DataTypes/TPCPadBitMap.cxx
210210
DataTypes/TPCZSLinkMapping.cxx
211211
DataTypes/CalibdEdxContainer.cxx
212+
DataTypes/ORTRootSerializer.cxx
212213
DataTypes/CalibdEdxTrackTopologyPol.cxx
213214
DataTypes/CalibdEdxTrackTopologySpline.cxx
214215
DataTypes/GPUTRDTrackO2.cxx)

GPU/GPUTracking/DataTypes/GPUDataTypes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class Cluster;
8585
namespace tpc
8686
{
8787
class CalibdEdxContainer;
88+
class ORTRootSerializer;
8889
} // namespace tpc
8990
} // namespace o2
9091

@@ -182,6 +183,9 @@ struct GPUCalibObjectsTemplate { // use only pointers on PODs or flat objects he
182183
typename S<o2::tpc::CalibdEdxContainer>::type* dEdxCalibContainer = nullptr;
183184
typename S<o2::base::PropagatorImpl<float>>::type* o2Propagator = nullptr;
184185
typename S<o2::itsmft::TopologyDictionary>::type* itsPatternDict = nullptr;
186+
187+
// NN clusterizer objects
188+
typename S<o2::tpc::ORTRootSerializer>::type* nnClusterizerNetworks[3] = {nullptr, nullptr, nullptr};
185189
};
186190
typedef GPUCalibObjectsTemplate<DefaultPtr> GPUCalibObjects; // NOTE: These 2 must have identical layout since they are memcopied
187191
typedef GPUCalibObjectsTemplate<ConstPtr> GPUCalibObjectsConst;
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
12+
/// \file ORTRootSerializer.cxx
13+
/// \author Christian Sonnabend <[email protected]>
14+
15+
#include "ORTRootSerializer.h"
16+
#include <cstring>
17+
18+
using namespace o2::tpc;
19+
20+
/// Initialize the serialization from a char* buffer containing the model
21+
void ORTRootSerializer::setOnnxModel(const char* onnxModel, uint32_t size)
22+
{
23+
mModelBuffer.resize(size);
24+
std::memcpy(mModelBuffer.data(), onnxModel, size);
25+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
12+
/// \file ORTRootSerializer.h
13+
/// \brief Class to serialize ONNX objects for ROOT snapshots of CCDB objects at runtime
14+
/// \author Christian Sonnabend <[email protected]>
15+
16+
#ifndef ALICEO2_TPC_ORTROOTSERIALIZER_H_
17+
#define ALICEO2_TPC_ORTROOTSERIALIZER_H_
18+
19+
#include "GPUCommonRtypes.h"
20+
#include <vector>
21+
#include <string>
22+
23+
namespace o2::tpc
24+
{
25+
26+
class ORTRootSerializer
27+
{
28+
public:
29+
ORTRootSerializer() = default;
30+
~ORTRootSerializer() = default;
31+
32+
void setOnnxModel(const char* onnxModel, uint32_t size);
33+
const char* getONNXModel() const { return mModelBuffer.data(); }
34+
uint32_t getONNXModelSize() const { return static_cast<uint32_t>(mModelBuffer.size()); }
35+
36+
private:
37+
std::vector<char> mModelBuffer; ///< buffer for serialization
38+
ClassDefNV(ORTRootSerializer, 1);
39+
};
40+
41+
} // namespace o2::tpc
42+
43+
#endif // ALICEO2_TPC_ORTROOTSERIALIZER_H_

0 commit comments

Comments
 (0)