Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HeterogeneousCore/SonicTriton/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<use name="FWCore/Utilities"/>
<use name="FWCore/ParameterSet"/>
<use name="FWCore/MessageLogger"/>
<use name="DataFormats/Common"/>
<use name="HeterogeneousCore/SonicCore"/>
<use name="triton-inference-server"/>
<use name="protobuf"/>
Expand Down
41 changes: 41 additions & 0 deletions HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#ifndef HeterogeneousCore_SonicTriton_TritonConverterBase
#define HeterogeneousCore_SonicTriton_TritonConverterBase

#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "DataFormats/Common/interface/Handle.h"

#include <string>

//Abstract interface for converter plugins that translate between client-side
//elements of type DT and the raw byte buffers handed to / received from the
//inference client.
//The class needs to be templated since the convert functions require the data
//type, but also need to be virtual, and virtual member function templates are
//not allowed in C++.
template <typename DT>
class TritonConverterBase {
public:
  //convName: name reported by name(); the element byte size defaults to sizeof(DT)
  explicit TritonConverterBase(const std::string& convName)
      : converterName_(convName), byteSize_(static_cast<int64_t>(sizeof(DT))) {}
  //convName: name reported by name()
  //byteSize: size in bytes of one converted element, for converters whose
  //          converted representation does not have the same size as DT
  TritonConverterBase(const std::string& convName, size_t byteSize)
      : converterName_(convName), byteSize_(static_cast<int64_t>(byteSize)) {}
  //converters are handled through base-class pointers, so copying is disabled
  TritonConverterBase(const TritonConverterBase&) = delete;
  virtual ~TritonConverterBase() = default;
  TritonConverterBase& operator=(const TritonConverterBase&) = delete;

  //convert client-side data into the byte representation
  //NOTE(review): the signature carries no element count and no ownership
  //information; implementations must return a pointer that stays valid after
  //the call returns - confirm the intended contract with the callers
  virtual const uint8_t* convertIn(const DT* in) const = 0;
  //convert a byte buffer back into client-side data (same caveats as convertIn)
  virtual const DT* convertOut(const uint8_t* in) const = 0;

  //size in bytes of one converted element
  int64_t byteSize() const { return byteSize_; }

  //name given to this converter at construction
  const std::string& name() const { return converterName_; }

private:
  const std::string converterName_;
  const int64_t byteSize_;
};

#include "FWCore/PluginManager/interface/PluginFactory.h"

//plugin factory producing TritonConverterBase<DT> pointers; the factory signature takes no constructor arguments
template <typename DT>
using TritonConverterFactory = edmplugin::PluginFactory<TritonConverterBase<DT>*()>;

//register converter class `type` under the plugin name `name` in the factory for data type `input`
#define DEFINE_TRITON_CONVERTER(input, type, name) DEFINE_EDM_PLUGIN(TritonConverterFactory<input>, type, name)
//as DEFINE_TRITON_CONVERTER, but the plugin name is the stringified class name (#type)
#define DEFINE_TRITON_CONVERTER_SIMPLE(input, type) DEFINE_EDM_PLUGIN(TritonConverterFactory<input>, type, #type)

#endif
15 changes: 15 additions & 0 deletions HeterogeneousCore/SonicTriton/interface/TritonData.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/Span.h"

#include "FWCore/PluginManager/interface/PluginFactory.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h"

#include <vector>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -40,6 +43,16 @@ class TritonData {
bool setShape(const ShapeType& newShape) { return setShape(newShape, true); }
bool setShape(unsigned loc, int64_t val) { return setShape(loc, val, true); }

//remember which converter plugin the configuration asks for; the plugin itself is only looked up in createConverter()
void setConverterParams(const edm::ParameterSet& conf) {
  const char* const paramName = "converterName";
  converterName_ = conf.getParameter<std::string>(paramName);
}
template <typename DT>
void createConverter() const {
using ConverterType = std::shared_ptr<TritonConverterBase<DT>>;
//this contruction catches bad any_cast without throwing std exception
if (auto ptr = std::any_cast<ConverterType>(&converter_)) {} else { converter_ = ConverterType(TritonConverterFactory<DT>::get()->create(converterName_)); }
}

//io accessors
template <typename DT>
void toServer(std::shared_ptr<TritonInput<DT>> ptr);
Expand Down Expand Up @@ -93,6 +106,8 @@ class TritonData {
int64_t byteSize_;
std::any holder_;
std::shared_ptr<Result> result_;
mutable std::any converter_;
std::string converterName_;
};

using TritonInputData = TritonData<nvidia::inferenceserver::client::InferInput>;
Expand Down
5 changes: 5 additions & 0 deletions HeterogeneousCore/SonicTriton/plugins/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<library name="HeterogeneousCoreSonicTritonPlugins_converters" file="converters/*.cc">
<use name="HeterogeneousCore/SonicTriton"/>
<use name="hls"/>
<flags EDM_PLUGIN="1"/>
</library>
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h"

#include <string>
#include "ap_fixed.h"

//Converter between float and ap_fixed<16, I>; the element byte size reported
//to the base class is 2.
//
//The converted values are kept in member buffers so that the pointers returned
//by convertIn()/convertOut() remain valid after the call returns (previously
//they pointed into a temporary vector that was destroyed before the caller
//could read it). A returned pointer is valid until the next call of the same
//function on the same converter object.
//NOTE(review): the buffers are modified from const member functions without
//synchronization - confirm that a converter instance is never used from
//several threads at once.
template <int I>
class FloatApFixed16Converter : public TritonConverterBase<float> {
public:
  FloatApFixed16Converter() : TritonConverterBase<float>("FloatApFixed16F" + std::to_string(I) + "Converter", 2) {}

  //convert floats to fixed-point values and expose them as raw bytes
  const uint8_t* convertIn(const float* in) const override {
    fixedBuffer_ = makeVecIn(in);
    return reinterpret_cast<const uint8_t*>(fixedBuffer_.data());
  }
  //interpret raw bytes as fixed-point values and convert them to floats
  const float* convertOut(const uint8_t* in) const override {
    floatBuffer_ = makeVecOut(reinterpret_cast<const ap_fixed<16, I>*>(in));
    return floatBuffer_.data();
  }

private:
  std::vector<ap_fixed<16, I>> makeVecIn(const float* in) const {
    //NOTE(review): sizeof(in) is the size of the pointer, not of the array it
    //points to, so nfeat is a platform constant unrelated to the number of
    //values supplied by the caller. The element count is not available through
    //the TritonConverterBase interface; fixing this requires passing it in.
    unsigned int nfeat = sizeof(in) / sizeof(float);
    return std::vector<ap_fixed<16, I>>(in, in + nfeat);
  }

  std::vector<float> makeVecOut(const ap_fixed<16, I>* in) const {
    //NOTE(review): same problem as in makeVecIn - sizeof(in) is the pointer size
    unsigned int nfeat = sizeof(in) / sizeof(ap_fixed<16, I>);
    return std::vector<float>(in, in + nfeat);
  }

  //storage backing the pointers handed out by convertIn() / convertOut()
  mutable std::vector<ap_fixed<16, I>> fixedBuffer_;
  mutable std::vector<float> floatBuffer_;
};

//register the I=6 instantiation as a float converter plugin; the plugin name matches the name built in the constructor for I=6
DEFINE_TRITON_CONVERTER(float, FloatApFixed16Converter<6>, "FloatApFixed16F6Converter");
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h"

//Pass-through converter for float data: the buffers are only reinterpreted,
//never copied, so the returned pointers alias the caller's memory.
//The element byte size is the base-class default, sizeof(float).
class FloatStandardConverter : public TritonConverterBase<float> {
public:
  FloatStandardConverter() : TritonConverterBase<float>("FloatStandardConverter") {}

  //view the float buffer as raw bytes
  const uint8_t* convertIn(const float* in) const override { return reinterpret_cast<const uint8_t*>(in); }
  //view the raw bytes as floats
  const float* convertOut(const uint8_t* in) const override { return reinterpret_cast<const float*>(in); }
};

//register as a float converter plugin under the stringified class name, "FloatStandardConverter"
DEFINE_TRITON_CONVERTER_SIMPLE(float, FloatStandardConverter);
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h"

//Pass-through converter for int64_t data: the buffers are only reinterpreted,
//never copied, so the returned pointers alias the caller's memory.
//The element byte size is the base-class default, sizeof(int64_t).
class Int64StandardConverter : public TritonConverterBase<int64_t> {
public:
  Int64StandardConverter() : TritonConverterBase<int64_t>("Int64StandardConverter") {}

  //view the int64_t buffer as raw bytes
  const uint8_t* convertIn(const int64_t* in) const override { return reinterpret_cast<const uint8_t*>(in); }
  //view the raw bytes as int64_t values
  const int64_t* convertOut(const uint8_t* in) const override { return reinterpret_cast<const int64_t*>(in); }
};

//register as an int64_t converter plugin under the stringified class name, "Int64StandardConverter"
DEFINE_TRITON_CONVERTER_SIMPLE(int64_t, Int64StandardConverter);
7 changes: 7 additions & 0 deletions HeterogeneousCore/SonicTriton/src/TritonClient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ TritonClient::TritonClient(const edm::ParameterSet& params)
if (!msg_str.empty())
throw cms::Exception("ModelErrors") << msg_str;

const edm::ParameterSet& converterDefs = params.getParameterSet("converterDefinition");
//setup input map
std::stringstream io_msg;
if (verbose_)
Expand All @@ -90,6 +91,7 @@ TritonClient::TritonClient(const edm::ParameterSet& params)
auto [curr_itr, success] = input_.emplace(
std::piecewise_construct, std::forward_as_tuple(iname), std::forward_as_tuple(iname, nicInput, noBatch_));
auto& curr_input = curr_itr->second;
curr_input.setConverterParams(converterDefs);
inputsTriton_.push_back(curr_input.data());
if (verbose_) {
io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
Expand All @@ -113,6 +115,7 @@ TritonClient::TritonClient(const edm::ParameterSet& params)
auto [curr_itr, success] = output_.emplace(
std::piecewise_construct, std::forward_as_tuple(oname), std::forward_as_tuple(oname, nicOutput, noBatch_));
auto& curr_output = curr_itr->second;
curr_output.setConverterParams(converterDefs);
outputsTriton_.push_back(curr_output.data());
if (verbose_) {
io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
Expand Down Expand Up @@ -336,10 +339,14 @@ inference::ModelStatistics TritonClient::getServerSideStatus() const {

//for fillDescriptions
void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
edm::ParameterSetDescription descConverter;
fillBasePSetDescription(descConverter);
descConverter.add<std::string>("converterName");
edm::ParameterSetDescription descClient;
fillBasePSetDescription(descClient);
descClient.add<std::string>("modelName");
descClient.add<std::string>("modelVersion", "");
descClient.add<edm::ParameterSetDescription>("converterDefinition", descConverter);
//server parameters should not affect the physics results
descClient.addUntracked<unsigned>("batchSize");
descClient.addUntracked<std::string>("address");
Expand Down
15 changes: 10 additions & 5 deletions HeterogeneousCore/SonicTriton/src/TritonData.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "HeterogeneousCore/SonicTriton/interface/TritonData.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"

#include "model_config.pb.h"
Expand Down Expand Up @@ -116,14 +117,16 @@ void TritonInputData::toServer(std::shared_ptr<TritonInput<DT>> ptr) {
//shape must be specified for variable dims or if batch size changes
data_->SetShape(fullShape_);

if (byteSize_ != sizeof(DT))
throw cms::Exception("TritonDataError") << name_ << " input(): inconsistent byte size " << sizeof(DT)
createConverter<DT>();

if (byteSize_ != std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->byteSize())
throw cms::Exception("TritonDataError") << name_ << " input(): inconsistent byte size " << std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->byteSize()
<< " (should be " << byteSize_ << " for " << dname_ << ")";

int64_t nInput = sizeShape();
for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
const DT* arr = data_in[i0].data();
triton_utils::throwIfError(data_->AppendRaw(reinterpret_cast<const uint8_t*>(arr), nInput * byteSize_),
triton_utils::throwIfError(data_->AppendRaw(std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->convertIn(arr), nInput * byteSize_),
name_ + " input(): unable to set data for batch entry " + std::to_string(i0));
}

Expand All @@ -138,6 +141,8 @@ TritonOutput<DT> TritonOutputData::fromServer() const {
throw cms::Exception("TritonDataError") << name_ << " output(): missing result";
}

createConverter<DT>();

if (byteSize_ != sizeof(DT)) {
throw cms::Exception("TritonDataError") << name_ << " output(): inconsistent byte size " << sizeof(DT)
<< " (should be " << byteSize_ << " for " << dname_ << ")";
Expand All @@ -147,14 +152,14 @@ TritonOutput<DT> TritonOutputData::fromServer() const {
TritonOutput<DT> dataOut;
const uint8_t* r0;
size_t contentByteSize;
size_t expectedContentByteSize = nOutput * byteSize_ * batchSize_;
size_t expectedContentByteSize = nOutput * std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->byteSize() * batchSize_;
triton_utils::throwIfError(result_->RawData(name_, &r0, &contentByteSize), "output(): unable to get raw");
if (contentByteSize != expectedContentByteSize) {
throw cms::Exception("TritonDataError") << name_ << " output(): unexpected content byte size " << contentByteSize
<< " (expected " << expectedContentByteSize << ")";
}

const DT* r1 = reinterpret_cast<const DT*>(r0);
const DT* r1 = std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->convertOut(r0);
dataOut.reserve(batchSize_);
for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
auto offset = i0 * nOutput;
Expand Down
4 changes: 4 additions & 0 deletions HeterogeneousCore/SonicTriton/src/pluginFactories.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h"

//register one converter plugin factory per supported data type
//NOTE(review): presumably every DT passed to TritonConverterFactory<DT> needs a matching registration here - confirm
EDM_REGISTER_PLUGINFACTORY(TritonConverterFactory<float>, "TritonConverterFloatFactory");
EDM_REGISTER_PLUGINFACTORY(TritonConverterFactory<int64_t>, "TritonConverterInt64Factory");
3 changes: 3 additions & 0 deletions HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
modelVersion = cms.string(""),
verbose = cms.untracked.bool(options.verbose),
allowedTries = cms.untracked.uint32(0),
converterDefinition = cms.PSet(
converterName = cms.string("FloatStandardConverter"),
),
)
)
if options.producer=="TritonImageProducer":
Expand Down