Skip to content

Commit 28eb55e

Browse files
committed
Fixes for latest 1.9 changes; cleanup
Signed-off-by: Boris Fomitchev <[email protected]>
1 parent dd0f6d5 commit 28eb55e

File tree

6 files changed

+42
-25
lines changed

6 files changed

+42
-25
lines changed

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ struct BuilderSettings {
3232
bool strict_types = false;
3333
bool truncate_long_and_double = false;
3434
Device device;
35-
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
35+
nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
3636
nvinfer1::IInt8Calibrator* calibrator = nullptr;
3737
uint64_t num_min_timing_iters = 2;
3838
uint64_t num_avg_timing_iters = 1;

core/util/trt_util.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ std::vector<int64_t> toVec(nvinfer1::Dims d) {
201201
for (int i = 0; i < d.nbDims; i++) {
202202
dims.push_back(d.d[i]);
203203
}
204-
return std::move(dims);
204+
return dims;
205205
}
206206

207207
std::string toStr(nvinfer1::Dims d) {

core/util/trt_util.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,27 @@
99
namespace nvinfer1 {
1010

1111
#if NV_TENSORRT_MAJOR < 8
12+
13+
#define TRT_ENGINE_CAPABILITY_STANDARD nvinfer1::EngineCapability::kDEFAULT
14+
#define TRT_ENGINE_CAPABILITY_SAFETY nvinfer1::EngineCapability::kSAFE_GPU
15+
#define TRT_ENGINE_CAPABILITY_DLA_STANDALONE nvinfer1::EngineCapability::kSAFE_DLA
16+
1217
template <class T>
1318
std::shared_ptr<T> make_trt(T* p) {
1419
return std::shared_ptr<T>(p, [](T* p){p->destroy();});
1520
}
21+
1622
#else
23+
24+
#define TRT_ENGINE_CAPABILITY_STANDARD nvinfer1::EngineCapability::kSTANDARD
25+
#define TRT_ENGINE_CAPABILITY_SAFETY nvinfer1::EngineCapability::kSAFETY
26+
#define TRT_ENGINE_CAPABILITY_DLA_STANDALONE nvinfer1::EngineCapability::kDLA_STANDALONE
27+
1728
template <class T>
1829
std::shared_ptr<T> make_trt(T* p) {
1930
return std::shared_ptr<T>(p);
2031
}
32+
2133
#endif
2234

2335
inline std::ostream& operator<<(std::ostream& os, const nvinfer1::TensorFormat& format) {
@@ -99,11 +111,11 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DeviceType
99111

100112
inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::EngineCapability& cap) {
101113
switch (cap) {
102-
case nvinfer1::EngineCapability::kDEFAULT:
114+
case TRT_ENGINE_CAPABILITY_STANDARD:
103115
return stream << "standard";
104-
case nvinfer1::EngineCapability::kSAFE_GPU:
116+
case TRT_ENGINE_CAPABILITY_SAFETY:
105117
return stream << "safety";
106-
case nvinfer1::EngineCapability::kSAFE_DLA:
118+
case TRT_ENGINE_CAPABILITY_DLA_STANDALONE:
107119
return stream << "DLA standalone";
108120
default:
109121
return stream << "Unknown Engine Capability Setting";

cpp/api/src/compile_spec.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,14 +387,14 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
387387

388388
switch (external.capability) {
389389
case CompileSpec::EngineCapability::kSAFETY:
390-
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_GPU;
390+
internal.convert_info.engine_settings.capability = TRT_ENGINE_CAPABILITY_SAFETY;
391391
break;
392392
case CompileSpec::EngineCapability::kDLA_STANDALONE:
393-
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_DLA;
393+
internal.convert_info.engine_settings.capability = TRT_ENGINE_CAPABILITY_DLA_STANDALONE;
394394
break;
395395
case CompileSpec::EngineCapability::kSTANDARD:
396396
default:
397-
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kDEFAULT;
397+
internal.convert_info.engine_settings.capability = TRT_ENGINE_CAPABILITY_STANDARD;
398398
}
399399

400400
internal.convert_info.engine_settings.device.gpu_id = external.device.gpu_id;

py/trtorch/csrc/tensorrt_backend.cpp

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,6 @@
1010
namespace trtorch {
1111
namespace backend {
1212

13-
namespace {
14-
c10::IValue preprocess(const torch::jit::Module& mod, const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec) {
15-
for (auto it = method_compile_spec.begin(), end = method_compile_spec.end(); it != end; ++it) {
16-
TRTORCH_CHECK(
17-
core::CheckMethodOperatorSupport(mod, it->key().toStringRef()),
18-
"Method " << it->key().toStringRef() << "cannot be compiled by TRTorch");
19-
}
20-
21-
return mod._ivalue();
22-
}
23-
} // namespace
24-
2513
c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::GenericDict method_compile_spec) {
2614
auto mod = mod_val.toModule();
2715
mod = core::lowering::LowerModule(mod);
@@ -77,8 +65,25 @@ c10::impl::GenericList TensorRTBackend::execute(c10::IValue handle, c10::impl::G
7765
}
7866

7967
namespace {
80-
static auto reg = torch::jit::backend<TensorRTBackend>("tensorrt");
81-
static auto preproc_reg = torch::jit::backend_preprocess_register("tensorrt", &preprocess);
68+
c10::IValue preprocess(const torch::jit::Module& mod,
69+
const c10::Dict<c10::IValue,
70+
#ifdef EARLY_PYTORCH_19X_VERSION
71+
c10::IValue>& method_compile_spec
72+
# else
73+
c10::IValue>& method_compile_spec, const torch::jit::BackendDebugHandleGenerator& generate_debug_handles
74+
#endif
75+
) {
76+
for (auto it = method_compile_spec.begin(), end = method_compile_spec.end(); it != end; ++it) {
77+
TRTORCH_CHECK(
78+
core::CheckMethodOperatorSupport(mod, it->key().toStringRef()),
79+
"Method " << it->key().toStringRef() << "cannot be compiled by TRTorch");
80+
}
81+
return mod._ivalue();
82+
};
83+
84+
static const std::string trt("tensorrt");
85+
static auto reg = torch::jit::backend<TensorRTBackend>(trt);
86+
static auto preproc_reg = torch::jit::backend_preprocess_register(trt, torch::jit::detail::BackendPreprocessFunction(preprocess));
8287
} // namespace
8388

8489
} // namespace backend

py/trtorch/csrc/tensorrt_classes.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,12 +146,12 @@ std::string to_str(EngineCapability value) {
146146
nvinfer1::EngineCapability toTRTEngineCapability(EngineCapability value) {
147147
switch (value) {
148148
case EngineCapability::kSAFE_DLA:
149-
return nvinfer1::EngineCapability::kSAFE_DLA;
149+
return TRT_ENGINE_CAPABILITY_DLA_STANDALONE;
150150
case EngineCapability::kSAFE_GPU:
151-
return nvinfer1::EngineCapability::kSAFE_GPU;
151+
return TRT_ENGINE_CAPABILITY_SAFETY;
152152
case EngineCapability::kDEFAULT:
153153
default:
154-
return nvinfer1::EngineCapability::kDEFAULT;
154+
return TRT_ENGINE_CAPABILITY_STANDARD;
155155
}
156156
}
157157

0 commit comments

Comments
 (0)