Skip to content

Commit 1f30051

Browse files
committed
fix: pybind issue and handling exception
Signed-off-by: Anurag Dixit <[email protected]>
1 parent c7675a5 commit 1f30051

File tree

1 file changed

+62
-9
lines changed

1 file changed

+62
-9
lines changed

py/torch_tensorrt/csrc/torch_tensorrt_py.cpp

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,16 @@ class pyCalibratorTrampoline : public Derived {
2323
using Derived::Derived; // Inherit constructors
2424

2525
int getBatchSize() const noexcept override {
26-
PYBIND11_OVERLOAD_PURE_NAME(int, Derived, "get_batch_size", getBatchSize);
26+
try {
27+
PYBIND11_OVERLOAD_PURE_NAME(int, Derived, "get_batch_size", getBatchSize);
28+
}
29+
catch(std::exception const &e) {
30+
LOG_ERROR("Exception caught in get_batch_size" + std::string(e.what()));
31+
}
32+
catch(...) {
33+
LOG_ERROR("Exception caught in get_batch_size");
34+
}
35+
return -1;
2736
}
2837

2938
bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override {
@@ -71,8 +80,17 @@ class pyIInt8Calibrator : public pyCalibratorTrampoline<nvinfer1::IInt8Calibrato
7180
using Derived::Derived;
7281

7382
nvinfer1::CalibrationAlgoType getAlgorithm() noexcept override {
74-
PYBIND11_OVERLOAD_PURE_NAME(
75-
nvinfer1::CalibrationAlgoType, nvinfer1::IInt8Calibrator, "get_algorithm", getAlgorithm);
83+
try {
84+
PYBIND11_OVERLOAD_PURE_NAME(
85+
nvinfer1::CalibrationAlgoType, nvinfer1::IInt8Calibrator, "get_algorithm", getAlgorithm);
86+
}
87+
catch(std::exception const &e) {
88+
LOG_ERROR("Exception caught in get_algorithm: " + std::string(e.what()));
89+
}
90+
catch(...) {
91+
LOG_ERROR("Exception caught in get_algorithm");
92+
}
93+
return {};
7694
}
7795
};
7896

@@ -82,21 +100,56 @@ class pyIInt8LegacyCalibrator : public pyCalibratorTrampoline<nvinfer1::IInt8Leg
82100
using Derived::Derived;
83101

84102
double getQuantile() const noexcept override {
85-
PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_quantile", getQuantile);
103+
try {
104+
PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_quantile", getQuantile);
105+
}
106+
catch(std::exception const &e) {
107+
LOG_ERROR("Exception caught in get_quantile: " + std::string(e.what()));
108+
}
109+
catch(...) {
110+
LOG_ERROR("Exception caught in get_quantile");
111+
}
112+
return -1.0;
86113
}
87114

88115
double getRegressionCutoff() const noexcept override {
89-
PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_regression_cutoff", getRegressionCutoff);
116+
try {
117+
PYBIND11_OVERLOAD_PURE_NAME(double, nvinfer1::IInt8LegacyCalibrator, "get_regression_cutoff", getRegressionCutoff);
118+
}
119+
catch(std::exception const &e) {
120+
LOG_ERROR("Exception caught in get_regression_cutoff: " + std::string(e.what()));
121+
}
122+
catch(...) {
123+
LOG_ERROR("Exception caught in get_regression_cutoff");
124+
}
125+
return -1.0;
90126
}
91127

92128
const void* readHistogramCache(std::size_t& length) noexcept override {
93-
PYBIND11_OVERLOAD_PURE_NAME(
94-
void const*, nvinfer1::IInt8LegacyCalibrator, "read_histogram_cache", readHistogramCache, length);
129+
try {
130+
PYBIND11_OVERLOAD_PURE_NAME(
131+
const char*, nvinfer1::IInt8LegacyCalibrator, "read_histogram_cache", readHistogramCache, length);
132+
}
133+
catch(std::exception const& e) {
134+
LOG_ERROR("Exception caught in read_histogram_cache" + std::string(e.what()));
135+
}
136+
catch(...) {
137+
LOG_ERROR("Exception caught in read_histogram_cache");
138+
}
139+
return {};
95140
}
96141

97142
void writeHistogramCache(const void* ptr, std::size_t length) noexcept override {
98-
PYBIND11_OVERLOAD_PURE_NAME(
99-
void, nvinfer1::IInt8LegacyCalibrator, "write_histogram_cache", writeHistogramCache, ptr, length);
143+
try {
144+
PYBIND11_OVERLOAD_PURE_NAME(
145+
void, nvinfer1::IInt8LegacyCalibrator, "write_histogram_cache", writeHistogramCache, ptr, length);
146+
}
147+
catch(std::exception const& e) {
148+
LOG_ERROR("Exception caught in write_histogram_cache" + std::string(e.what()));
149+
}
150+
catch(...) {
151+
LOG_ERROR("Exception caught in write_histogram_cache");
152+
}
100153
}
101154
};
102155

0 commit comments

Comments
 (0)