1
1
#include " tensorrt_classes.h"
2
2
3
- namespace trtorch {
3
+ namespace torch_tensorrt {
4
+ namespace torchscript {
4
5
namespace backend {
5
6
namespace {
6
7
@@ -9,58 +10,65 @@ namespace {
9
10
(registry).def(" _get_" #field_name, &class_name::get_##field_name);
10
11
11
12
void RegisterTRTCompileSpec () {
12
- static auto TRTORCH_UNUSED TRTInputRangeTSRegistration = torch::class_<trtorch::pyapi::Input>(" tensorrt" , " _Input" )
13
- .def (torch::init<>())
14
- .def (" __str__" , &trtorch::pyapi::Input::to_str);
13
+ static auto TORCHTRT_UNUSED TRTInputRangeTSRegistration =
14
+ torch::class_<torch_tensorrt::pyapi::Input>(" tensorrt" , " _Input" )
15
+ .def (torch::init<>())
16
+ .def (" __str__" , &torch_tensorrt::pyapi::Input::to_str);
15
17
16
- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, min);
17
- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, opt);
18
- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, max);
19
- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, dtype);
20
- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, format);
21
- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, input_is_dynamic);
22
- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, explicit_set_dtype);
18
+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, min);
19
+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, opt);
20
+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, max);
21
+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, dtype);
22
+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, format);
23
+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, input_is_dynamic);
24
+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, explicit_set_dtype);
23
25
24
- static auto TRTORCH_UNUSED TRTDeviceTSRegistration = torch::class_<trtorch::pyapi::Device>(" tensorrt" , " _Device" )
25
- .def (torch::init<>())
26
- .def (" __str__" , &trtorch::pyapi::Device::to_str);
26
+ static auto TORCHTRT_UNUSED TRTDeviceTSRegistration =
27
+ torch::class_<torch_tensorrt::pyapi::Device>(" tensorrt" , " _Device" )
28
+ .def (torch::init<>())
29
+ .def (" __str__" , &torch_tensorrt::pyapi::Device::to_str);
27
30
28
- ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, trtorch ::pyapi::Device, device_type);
29
- ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, trtorch ::pyapi::Device, gpu_id);
30
- ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, trtorch ::pyapi::Device, dla_core);
31
- ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, trtorch ::pyapi::Device, allow_gpu_fallback);
31
+ ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, torch_tensorrt ::pyapi::Device, device_type);
32
+ ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, torch_tensorrt ::pyapi::Device, gpu_id);
33
+ ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, torch_tensorrt ::pyapi::Device, dla_core);
34
+ ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, torch_tensorrt ::pyapi::Device, allow_gpu_fallback);
32
35
33
- static auto TRTORCH_UNUSED TRTFallbackTSRegistration =
34
- torch::class_<trtorch ::pyapi::TorchFallback>(" tensorrt" , " _TorchFallback" )
36
+ static auto TORCHTRT_UNUSED TRTFallbackTSRegistration =
37
+ torch::class_<torch_tensorrt ::pyapi::TorchFallback>(" tensorrt" , " _TorchFallback" )
35
38
.def (torch::init<>())
36
- .def (" __str__" , &trtorch ::pyapi::TorchFallback::to_str);
39
+ .def (" __str__" , &torch_tensorrt ::pyapi::TorchFallback::to_str);
37
40
38
- ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
39
- ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
40
- ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_operators);
41
- ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_modules);
41
+ ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, enabled);
42
+ ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, min_block_size);
43
+ ADD_FIELD_GET_SET_REGISTRATION (
44
+ TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, forced_fallback_operators);
45
+ ADD_FIELD_GET_SET_REGISTRATION (
46
+ TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, forced_fallback_modules);
42
47
43
- static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
44
- torch::class_<trtorch ::pyapi::CompileSpec>(" tensorrt" , " CompileSpec" )
48
+ static auto TORCHTRT_UNUSED TRTCompileSpecTSRegistration =
49
+ torch::class_<torch_tensorrt ::pyapi::CompileSpec>(" tensorrt" , " CompileSpec" )
45
50
.def (torch::init<>())
46
- .def (" _append_input" , &trtorch ::pyapi::CompileSpec::appendInput)
47
- .def (" _set_precisions" , &trtorch ::pyapi::CompileSpec::setPrecisions)
48
- .def (" _set_device" , &trtorch ::pyapi::CompileSpec::setDeviceIntrusive)
49
- .def (" _set_torch_fallback" , &trtorch ::pyapi::CompileSpec::setTorchFallbackIntrusive)
50
- .def (" _set_ptq_calibrator" , &trtorch ::pyapi::CompileSpec::setPTQCalibratorViaHandle)
51
- .def (" __str__" , &trtorch ::pyapi::CompileSpec::stringify);
51
+ .def (" _append_input" , &torch_tensorrt ::pyapi::CompileSpec::appendInput)
52
+ .def (" _set_precisions" , &torch_tensorrt ::pyapi::CompileSpec::setPrecisions)
53
+ .def (" _set_device" , &torch_tensorrt ::pyapi::CompileSpec::setDeviceIntrusive)
54
+ .def (" _set_torch_fallback" , &torch_tensorrt ::pyapi::CompileSpec::setTorchFallbackIntrusive)
55
+ .def (" _set_ptq_calibrator" , &torch_tensorrt ::pyapi::CompileSpec::setPTQCalibratorViaHandle)
56
+ .def (" __str__" , &torch_tensorrt ::pyapi::CompileSpec::stringify);
52
57
53
- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, sparse_weights);
54
- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
55
- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
56
- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, debug);
57
- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, strict_types);
58
- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, capability);
59
- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_min_timing_iters);
60
- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_avg_timing_iters);
61
- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, workspace_size);
62
- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, max_batch_size);
63
- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, truncate_long_and_double);
58
+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, sparse_weights);
59
+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, disable_tf32);
60
+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, refit);
61
+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, debug);
62
+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, strict_types);
63
+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, capability);
64
+ ADD_FIELD_GET_SET_REGISTRATION (
65
+ TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, num_min_timing_iters);
66
+ ADD_FIELD_GET_SET_REGISTRATION (
67
+ TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, num_avg_timing_iters);
68
+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, workspace_size);
69
+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, max_batch_size);
70
+ ADD_FIELD_GET_SET_REGISTRATION (
71
+ TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, truncate_long_and_double);
64
72
}
65
73
66
74
struct TRTTSRegistrations {
@@ -72,4 +80,5 @@ struct TRTTSRegistrations {
72
80
static TRTTSRegistrations register_trt_classes = TRTTSRegistrations();
73
81
} // namespace
74
82
} // namespace backend
75
- } // namespace trtorch
83
+ } // namespace torchscript
84
+ } // namespace torch_tensorrt
0 commit comments