11#include " tensorrt_classes.h"
22
3- namespace trtorch {
3+ namespace torch_tensorrt {
4+ namespace torchscript {
45namespace backend {
56namespace {
67
@@ -9,58 +10,65 @@ namespace {
910 (registry).def(" _get_" #field_name, &class_name::get_##field_name);
1011
1112void RegisterTRTCompileSpec () {
12- static auto TRTORCH_UNUSED TRTInputRangeTSRegistration = torch::class_<trtorch::pyapi::Input>(" tensorrt" , " _Input" )
13- .def (torch::init<>())
14- .def (" __str__" , &trtorch::pyapi::Input::to_str);
13+ static auto TORCHTRT_UNUSED TRTInputRangeTSRegistration =
14+ torch::class_<torch_tensorrt::pyapi::Input>(" tensorrt" , " _Input" )
15+ .def (torch::init<>())
16+ .def (" __str__" , &torch_tensorrt::pyapi::Input::to_str);
1517
16- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, min);
17- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, opt);
18- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, max);
19- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, dtype);
20- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, format);
21- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, input_is_dynamic);
22- ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch ::pyapi::Input, explicit_set_dtype);
18+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, min);
19+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, opt);
20+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, max);
21+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, dtype);
22+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, format);
23+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, input_is_dynamic);
24+ ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, torch_tensorrt ::pyapi::Input, explicit_set_dtype);
2325
24- static auto TRTORCH_UNUSED TRTDeviceTSRegistration = torch::class_<trtorch::pyapi::Device>(" tensorrt" , " _Device" )
25- .def (torch::init<>())
26- .def (" __str__" , &trtorch::pyapi::Device::to_str);
26+ static auto TORCHTRT_UNUSED TRTDeviceTSRegistration =
27+ torch::class_<torch_tensorrt::pyapi::Device>(" tensorrt" , " _Device" )
28+ .def (torch::init<>())
29+ .def (" __str__" , &torch_tensorrt::pyapi::Device::to_str);
2730
28- ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, trtorch ::pyapi::Device, device_type);
29- ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, trtorch ::pyapi::Device, gpu_id);
30- ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, trtorch ::pyapi::Device, dla_core);
31- ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, trtorch ::pyapi::Device, allow_gpu_fallback);
31+ ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, torch_tensorrt ::pyapi::Device, device_type);
32+ ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, torch_tensorrt ::pyapi::Device, gpu_id);
33+ ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, torch_tensorrt ::pyapi::Device, dla_core);
34+ ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, torch_tensorrt ::pyapi::Device, allow_gpu_fallback);
3235
33- static auto TRTORCH_UNUSED TRTFallbackTSRegistration =
34- torch::class_<trtorch ::pyapi::TorchFallback>(" tensorrt" , " _TorchFallback" )
36+ static auto TORCHTRT_UNUSED TRTFallbackTSRegistration =
37+ torch::class_<torch_tensorrt ::pyapi::TorchFallback>(" tensorrt" , " _TorchFallback" )
3538 .def (torch::init<>())
36- .def (" __str__" , &trtorch ::pyapi::TorchFallback::to_str);
39+ .def (" __str__" , &torch_tensorrt ::pyapi::TorchFallback::to_str);
3740
38- ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
39- ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
40- ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_operators);
41- ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, forced_fallback_modules);
41+ ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, enabled);
42+ ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, min_block_size);
43+ ADD_FIELD_GET_SET_REGISTRATION (
44+ TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, forced_fallback_operators);
45+ ADD_FIELD_GET_SET_REGISTRATION (
46+ TRTFallbackTSRegistration, torch_tensorrt::pyapi::TorchFallback, forced_fallback_modules);
4247
43- static auto TRTORCH_UNUSED TRTCompileSpecTSRegistration =
44- torch::class_<trtorch ::pyapi::CompileSpec>(" tensorrt" , " CompileSpec" )
48+ static auto TORCHTRT_UNUSED TRTCompileSpecTSRegistration =
49+ torch::class_<torch_tensorrt ::pyapi::CompileSpec>(" tensorrt" , " CompileSpec" )
4550 .def (torch::init<>())
46- .def (" _append_input" , &trtorch ::pyapi::CompileSpec::appendInput)
47- .def (" _set_precisions" , &trtorch ::pyapi::CompileSpec::setPrecisions)
48- .def (" _set_device" , &trtorch ::pyapi::CompileSpec::setDeviceIntrusive)
49- .def (" _set_torch_fallback" , &trtorch ::pyapi::CompileSpec::setTorchFallbackIntrusive)
50- .def (" _set_ptq_calibrator" , &trtorch ::pyapi::CompileSpec::setPTQCalibratorViaHandle)
51- .def (" __str__" , &trtorch ::pyapi::CompileSpec::stringify);
51+ .def (" _append_input" , &torch_tensorrt ::pyapi::CompileSpec::appendInput)
52+ .def (" _set_precisions" , &torch_tensorrt ::pyapi::CompileSpec::setPrecisions)
53+ .def (" _set_device" , &torch_tensorrt ::pyapi::CompileSpec::setDeviceIntrusive)
54+ .def (" _set_torch_fallback" , &torch_tensorrt ::pyapi::CompileSpec::setTorchFallbackIntrusive)
55+ .def (" _set_ptq_calibrator" , &torch_tensorrt ::pyapi::CompileSpec::setPTQCalibratorViaHandle)
56+ .def (" __str__" , &torch_tensorrt ::pyapi::CompileSpec::stringify);
5257
53- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, sparse_weights);
54- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
55- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
56- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, debug);
57- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, strict_types);
58- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, capability);
59- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_min_timing_iters);
60- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_avg_timing_iters);
61- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, workspace_size);
62- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, max_batch_size);
63- ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, truncate_long_and_double);
58+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, sparse_weights);
59+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, disable_tf32);
60+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, refit);
61+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, debug);
62+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, strict_types);
63+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, capability);
64+ ADD_FIELD_GET_SET_REGISTRATION (
65+ TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, num_min_timing_iters);
66+ ADD_FIELD_GET_SET_REGISTRATION (
67+ TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, num_avg_timing_iters);
68+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, workspace_size);
69+ ADD_FIELD_GET_SET_REGISTRATION (TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, max_batch_size);
70+ ADD_FIELD_GET_SET_REGISTRATION (
71+ TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, truncate_long_and_double);
6472}
6573
6674struct TRTTSRegistrations {
@@ -72,4 +80,5 @@ struct TRTTSRegistrations {
7280static TRTTSRegistrations register_trt_classes = TRTTSRegistrations();
7381} // namespace
7482} // namespace backend
75- } // namespace trtorch
83+ } // namespace torchscript
84+ } // namespace torch_tensorrt
0 commit comments