Skip to content

Commit 493c19f

Browse files
committed
refactor: Realign py api
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 715a0ce commit 493c19f

File tree

8 files changed

+101
-88
lines changed

8 files changed

+101
-88
lines changed

py/torch_tensorrt/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.0a0+483ef591"
1+
__version__ = "1.0.0a0+715a0cea"
Lines changed: 54 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "tensorrt_classes.h"
22

3-
namespace trtorch {
3+
namespace torch_tensorrt {
4+
namespace torchscript {
45
namespace backend {
56
namespace {
67

@@ -9,58 +10,65 @@ namespace {
910
(registry).def("_get_" #field_name, &class_name::get_##field_name);
1011

1112
void RegisterTRTCompileSpec() {
12-
static auto TRTORCH_UNUSED TRTInputRangeTSRegistration = torch::class_<trtorch::pyapi::Input>("tensorrt", "_Input")
13-
.def(torch::init<>())
14-
.def("__str__", &trtorch::pyapi::Input::to_str);
13+
static auto TORCHTRT_UNUSED TRTInputRangeTSRegistration =
14+
torch::class_<torch_tensorrt::pyapi::Input>("tensorrt", "_Input")
15+
.def(torch::init<>())
16+
.def("__str__", &torch_tensorrt::pyapi::Input::to_str);
1517

16-
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, min);
17-
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, opt);
18-
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, max);
19-
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, dtype);
20-
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, format);
21-
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, input_is_dynamic);
22-
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, explicit_set_dtype);
18+
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, min);
19+
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, opt);
20+
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, max);
21+
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, dtype);
22+
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, format);
23+
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, input_is_dynamic);
24+
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, explicit_set_dtype);
2325

24-
static auto TRTORCH_UNUSED TRTDeviceTSRegistration = torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
25-
.def(torch::init<>())
26-
.def("__str__", &trtorch::pyapi::Device::to_str);
26+
static auto TORCHTRT_UNUSED TRTDeviceTSRegistration =
27+
torch::class_<torch_tensorrt::pyapi::Device>("tensorrt", "_Device")
28+
.def(torch::init<>())
29+
.def("__str__", &torch_tensorrt::pyapi::Device::to_str);
2730

28-
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type);
29-
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id);
30-
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
31-
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);
31+
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, torch_tensorrt::pyapi::Device, device_type);
32+
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, torch_tensorrt::pyapi::Device, gpu_id);
33+
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, torch_tensorrt::pyapi::Device, dla_core);
34+
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, torch_tensorrt::pyapi::Device, allow_gpu_fallback);
3235

33-
static auto TRTORCH_UNUSED TRTFallbackTSRegistration =
34-
torch::class_<trtorch::pyapi::TorchFallback>("tensorrt", "_TorchFallback")
36+
static auto TORCHTRT_UNUSED TRTFallbackTSRegistration =
37+
torch::class_<torch_tensorrt::pyapi::TorchFallback>("tensorrt", "_TorchFallback")
3538
.def(torch::init<>())
36-
.def("__str__", &trtorch::pyapi::TorchFallback::to_str);
39+
.def("__str__", &torch_tensorrt::pyapi::TorchFallback::to_str);
3740

38-
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
39-
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
40-
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_operators);
41-
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_modules);
41+
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, enabled);
42+
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, min_block_size);
43+
ADD_FIELD_GET_SET_REGISTRATION(
44+
TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, forced_fallback_operators);
45+
ADD_FIELD_GET_SET_REGISTRATION(
46+
TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, forced_fallback_modules);
4247

43-
static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
44-
torch::class_<trtorch::pyapi::CompileSpec>("tensorrt", "CompileSpec")
48+
static auto TORCHTRT_UNUSED TRTCompileSpecTSRegistration =
49+
torch::class_<torch_tensorrt::pyapi::CompileSpec>("tensorrt", "CompileSpec")
4550
.def(torch::init<>())
46-
.def("_append_input", &trtorch::pyapi::CompileSpec::appendInput)
47-
.def("_set_precisions", &trtorch::pyapi::CompileSpec::setPrecisions)
48-
.def("_set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
49-
.def("_set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
50-
.def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
51-
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);
51+
.def("_append_input", &torch_tensorrt::pyapi::CompileSpec::appendInput)
52+
.def("_set_precisions", &torch_tensorrt::pyapi::CompileSpec::setPrecisions)
53+
.def("_set_device", &torch_tensorrt::pyapi::CompileSpec::setDeviceIntrusive)
54+
.def("_set_torch_fallback", &torch_tensorrt::pyapi::CompileSpec::setTorchFallbackIntrusive)
55+
.def("_set_ptq_calibrator", &torch_tensorrt::pyapi::CompileSpec::setPTQCalibratorViaHandle)
56+
.def("__str__", &torch_tensorrt::pyapi::CompileSpec::stringify);
5257

53-
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, sparse_weights);
54-
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
55-
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
56-
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, debug);
57-
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, strict_types);
58-
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, capability);
59-
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_min_timing_iters);
60-
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_avg_timing_iters);
61-
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, workspace_size);
62-
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, max_batch_size);
63-
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, truncate_long_and_double);
58+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, sparse_weights);
59+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, disable_tf32);
60+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, refit);
61+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, debug);
62+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, strict_types);
63+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, capability);
64+
ADD_FIELD_GET_SET_REGISTRATION(
65+
TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, num_min_timing_iters);
66+
ADD_FIELD_GET_SET_REGISTRATION(
67+
TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, num_avg_timing_iters);
68+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, workspace_size);
69+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, max_batch_size);
70+
ADD_FIELD_GET_SET_REGISTRATION(
71+
TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, truncate_long_and_double);
6472
}
6573

6674
struct TRTTSRegistrations {
@@ -72,4 +80,5 @@ struct TRTTSRegistrations {
7280
static TRTTSRegistrations register_trt_classes = TRTTSRegistrations();
7381
} // namespace
7482
} // namespace backend
75-
} // namespace trtorch
83+
} // namespace torchscript
84+
} // namespace torch_tensorrt

py/torch_tensorrt/csrc/tensorrt_backend.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
#include "core/lowering/lowering.h"
88
#include "core/runtime/runtime.h"
99

10-
namespace trtorch {
10+
namespace torch_tensorrt {
11+
namespace torchscript {
1112
namespace backend {
1213

1314
c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::GenericDict method_compile_spec) {
@@ -21,7 +22,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::
2122
for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
2223
auto mod_ = mod.clone();
2324
const auto& method_name = it->key();
24-
auto raw_spec = it->value().toCustomClass<trtorch::pyapi::CompileSpec>();
25+
auto raw_spec = it->value().toCustomClass<torch_tensorrt::pyapi::CompileSpec>();
2526
LOG_DEBUG(raw_spec->stringify());
2627

2728
auto cfg = raw_spec->toInternalCompileSpec();
@@ -37,12 +38,12 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::
3738
}
3839

3940
c10::impl::GenericList TensorRTBackend::execute(c10::IValue handle, c10::impl::GenericList inputs) {
40-
TRTORCH_ASSERT(inputs.size() > 0, "Trying to execute on empty list of arguments");
41+
TORCHTRT_ASSERT(inputs.size() > 0, "Trying to execute on empty list of arguments");
4142
auto engine = handle.toCustomClass<core::runtime::TRTEngine>();
4243
std::vector<at::Tensor> in_vec;
4344
for (size_t i = 0, e = inputs.size(); i < e; ++i) {
4445
c10::IValue val = inputs[i];
45-
TRTORCH_CHECK(val.isTensor(), "TensorRT currently only accepts Tensors as inputs");
46+
TORCHTRT_CHECK(val.isTensor(), "TensorRT currently only accepts Tensors as inputs");
4647
in_vec.push_back(val.toTensor());
4748
}
4849
auto outputs = core::runtime::execute_engine(in_vec, engine);
@@ -63,7 +64,7 @@ c10::IValue preprocess(
6364
#endif
6465
) {
6566
for (auto it = method_compile_spec.begin(), end = method_compile_spec.end(); it != end; ++it) {
66-
TRTORCH_CHECK(
67+
TORCHTRT_CHECK(
6768
core::CheckMethodOperatorSupport(mod, it->key().toStringRef()),
6869
"Method " << it->key().toStringRef() << "cannot be compiled by TRTorch");
6970
}
@@ -77,4 +78,5 @@ static auto preproc_reg =
7778
} // namespace
7879

7980
} // namespace backend
80-
} // namespace trtorch
81+
} // namespace torchscript
82+
} // namespace torch_tensorrt

py/torch_tensorrt/csrc/tensorrt_backend.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
#include "torch/csrc/jit/backends/backend_debug_handler.h"
55
#include "torch/csrc/jit/backends/backend_preprocess.h"
66

7-
namespace trtorch {
7+
namespace torch_tensorrt {
8+
namespace torchscript {
89
namespace backend {
910

1011
class TensorRTBackend : public torch::jit::PyTorchBackendInterface {
@@ -21,4 +22,5 @@ class TensorRTBackend : public torch::jit::PyTorchBackendInterface {
2122
};
2223

2324
} // namespace backend
24-
} // namespace trtorch
25+
} // namespace torchscript
26+
} // namespace torch_tensorrt

py/torch_tensorrt/csrc/tensorrt_classes.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
#include "tensorrt_classes.h"
33

4-
namespace trtorch {
4+
namespace torch_tensorrt {
55
namespace pyapi {
66

77
std::string to_str(DataType value) {
@@ -36,7 +36,7 @@ nvinfer1::DataType toTRTDataType(DataType value) {
3636
case DataType::kUnknown:
3737
return nvinfer1::DataType::kFLOAT;
3838
default:
39-
TRTORCH_THROW_ERROR("Unknown data type: " << to_str(value));
39+
TORCHTRT_THROW_ERROR("Unknown data type: " << to_str(value));
4040
}
4141
}
4242

@@ -221,13 +221,13 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
221221
info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;
222222

223223
info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
224-
TRTORCH_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater");
224+
TORCHTRT_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater");
225225
info.convert_info.engine_settings.num_min_timing_iters = num_min_timing_iters;
226-
TRTORCH_CHECK(num_avg_timing_iters >= 0, "num_avg_timing_iters must be 0 or greater");
226+
TORCHTRT_CHECK(num_avg_timing_iters >= 0, "num_avg_timing_iters must be 0 or greater");
227227
info.convert_info.engine_settings.num_avg_timing_iters = num_avg_timing_iters;
228-
TRTORCH_CHECK(workspace_size >= 0, "workspace_size must be 0 or greater");
228+
TORCHTRT_CHECK(workspace_size >= 0, "workspace_size must be 0 or greater");
229229
info.convert_info.engine_settings.workspace_size = workspace_size;
230-
TRTORCH_CHECK(max_batch_size >= 0, "max_batch_size must be 0 or greater");
230+
TORCHTRT_CHECK(max_batch_size >= 0, "max_batch_size must be 0 or greater");
231231
info.convert_info.engine_settings.max_batch_size = max_batch_size;
232232
return info;
233233
}
@@ -263,4 +263,4 @@ std::string CompileSpec::stringify() {
263263
}
264264

265265
} // namespace pyapi
266-
} // namespace trtorch
266+
} // namespace torch_tensorrt

py/torch_tensorrt/csrc/tensorrt_classes.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "torch/script.h"
77
#include "torch/torch.h"
88

9-
namespace trtorch {
9+
namespace torch_tensorrt {
1010
namespace pyapi {
1111

1212
#define ADD_FIELD_GET_SET(field_name, type) \
@@ -18,13 +18,13 @@ namespace pyapi {
1818
}
1919

2020
// TODO: Make this error message more informative
21-
#define ADD_ENUM_GET_SET(field_name, type, max_val) \
22-
void set_##field_name(int64_t val) { \
23-
TRTORCH_CHECK(val >= 0 && val <= max_val, "Invalid enum value for field"); \
24-
field_name = static_cast<type>(val); \
25-
} \
26-
int64_t get_##field_name() { \
27-
return static_cast<int64_t>(field_name); \
21+
#define ADD_ENUM_GET_SET(field_name, type, max_val) \
22+
void set_##field_name(int64_t val) { \
23+
TORCHTRT_CHECK(val >= 0 && val <= max_val, "Invalid enum value for field"); \
24+
field_name = static_cast<type>(val); \
25+
} \
26+
int64_t get_##field_name() { \
27+
return static_cast<int64_t>(field_name); \
2828
}
2929

3030
enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool, kUnknown };
@@ -121,7 +121,7 @@ struct CompileSpec : torch::CustomClassHolder {
121121

122122
void setPrecisions(const std::vector<int64_t>& precisions_raw) {
123123
for (auto p : precisions_raw) {
124-
TRTORCH_CHECK(p >= 0 && p <= static_cast<int64_t>(DataType::kBool), "Invalid enum value for field");
124+
TORCHTRT_CHECK(p >= 0 && p <= static_cast<int64_t>(DataType::kBool), "Invalid enum value for field");
125125
enabled_precisions.insert(static_cast<DataType>(p));
126126
}
127127
}
@@ -176,4 +176,4 @@ struct CompileSpec : torch::CustomClassHolder {
176176
};
177177

178178
} // namespace pyapi
179-
} // namespace trtorch
179+
} // namespace torch_tensorrt

0 commit comments

Comments
 (0)