Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions HeterogeneousCore/SonicTriton/interface/TritonData.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ class TritonData {
converterName_ = conf.getParameter<std::string>("converterName");
}
template <typename DT>
std::unique_ptr<TritonConverterBase<DT>> createConverter() const { return TritonConverterFactory<DT>::get()->create(converterName_); }
void createConverter() const {
if (!converter_.has_value()) converter_ = std::shared_ptr<TritonConverterBase<DT>>(TritonConverterFactory<DT>::get()->create(converterName_));
}

//io accessors
template <typename DT>
Expand Down Expand Up @@ -102,7 +104,7 @@ class TritonData {
int64_t byteSize_;
std::any holder_;
std::shared_ptr<Result> result_;
std::any converter_;
mutable std::any converter_;
std::string converterName_;
};

Expand Down
21 changes: 8 additions & 13 deletions HeterogeneousCore/SonicTriton/src/TritonData.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,16 @@ void TritonInputData::toServer(std::shared_ptr<TritonInput<DT>> ptr) {
//shape must be specified for variable dims or if batch size changes
data_->SetShape(fullShape_);

std::unique_ptr<TritonConverterBase<DT>> converter = createConverter<DT>();
createConverter<DT>();

if (byteSize_ != converter->byteSize())
throw cms::Exception("TritonDataError") << name_ << " input(): inconsistent byte size " << converter->byteSize()
if (byteSize_ != std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->byteSize())
throw cms::Exception("TritonDataError") << name_ << " input(): inconsistent byte size " << std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->byteSize()
<< " (should be " << byteSize_ << " for " << dname_ << ")";

int64_t nInput = sizeShape();
for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
const DT* arr = data_in[i0].data();
triton_utils::throwIfError(data_->AppendRaw(converter->convertIn(arr), nInput * byteSize_),
triton_utils::throwIfError(data_->AppendRaw(std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->convertIn(arr), nInput * byteSize_),
name_ + " input(): unable to set data for batch entry " + std::to_string(i0));
}

Expand All @@ -141,7 +141,8 @@ TritonOutput<DT> TritonOutputData::fromServer() const {
throw cms::Exception("TritonDataError") << name_ << " output(): missing result";
}

std::unique_ptr<TritonConverterBase<DT>> converter = createConverter<DT>();
createConverter<DT>();
//std::unique_ptr<TritonConverterBase<DT>> converter = std::any_cast<converter>;

if (byteSize_ != sizeof(DT)) {
throw cms::Exception("TritonDataError") << name_ << " output(): inconsistent byte size " << sizeof(DT)
Expand All @@ -152,14 +153,14 @@ TritonOutput<DT> TritonOutputData::fromServer() const {
TritonOutput<DT> dataOut;
const uint8_t* r0;
size_t contentByteSize;
size_t expectedContentByteSize = nOutput * converter->byteSize() * batchSize_;
size_t expectedContentByteSize = nOutput * std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->byteSize() * batchSize_;
triton_utils::throwIfError(result_->RawData(name_, &r0, &contentByteSize), "output(): unable to get raw");
if (contentByteSize != expectedContentByteSize) {
throw cms::Exception("TritonDataError") << name_ << " output(): unexpected content byte size " << contentByteSize
<< " (expected " << expectedContentByteSize << ")";
}

const DT* r1 = converter->convertOut(r0);
const DT* r1 = std::any_cast<std::shared_ptr<TritonConverterBase<DT>>>(converter_)->convertOut(r0);
dataOut.reserve(batchSize_);
for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
auto offset = i0 * nOutput;
Expand Down Expand Up @@ -188,9 +189,3 @@ template void TritonInputData::toServer(std::shared_ptr<TritonInput<float>> data
template void TritonInputData::toServer(std::shared_ptr<TritonInput<int64_t>> data_in);

template TritonOutput<float> TritonOutputData::fromServer() const;

template std::unique_ptr<TritonConverterBase<float>> TritonInputData::createConverter() const;
template std::unique_ptr<TritonConverterBase<int64_t>> TritonInputData::createConverter() const;

template std::unique_ptr<TritonConverterBase<float>> TritonOutputData::createConverter() const;
template std::unique_ptr<TritonConverterBase<int64_t>> TritonOutputData::createConverter() const;