@@ -5,33 +5,32 @@ namespace backend {
5
5
namespace {
6
6
7
7
#define ADD_FIELD_GET_SET_REGISTRATION (registry, class_name, field_name ) \
8
- (registry).def(" _set_" #field_name, &class_name::set_##field_name); \
8
+ (registry).def(" _set_" #field_name, &class_name::set_##field_name); \
9
9
(registry).def(" _get_" #field_name, &class_name::get_##field_name);
10
10
11
11
void RegisterTRTCompileSpec () {
12
12
static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
13
- torch::class_<trtorch::pyapi::InputRange>(" tensorrt" , " _InputRange" )
14
- .def (torch::init<>())
15
- .def (" __str__" , &trtorch::pyapi::InputRange::to_str);
13
+ torch::class_<trtorch::pyapi::InputRange>(" tensorrt" , " _InputRange" )
14
+ .def (torch::init<>())
15
+ .def (" __str__" , &trtorch::pyapi::InputRange::to_str);
16
16
17
17
ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min);
18
18
ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt);
19
19
ADD_FIELD_GET_SET_REGISTRATION (TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max);
20
20
21
- static auto TRTORCH_UNUSED TRTDeviceTSRegistration =
22
- torch::class_<trtorch::pyapi::Device>(" tensorrt" , " _Device" )
23
- .def (torch::init<>())
24
- .def (" __str__" , &trtorch::pyapi::Device::to_str);
21
+ static auto TRTORCH_UNUSED TRTDeviceTSRegistration = torch::class_<trtorch::pyapi::Device>(" tensorrt" , " _Device" )
22
+ .def (torch::init<>())
23
+ .def (" __str__" , &trtorch::pyapi::Device::to_str);
25
24
26
25
ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type);
27
26
ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id);
28
27
ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
29
28
ADD_FIELD_GET_SET_REGISTRATION (TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);
30
29
31
30
static auto TRTORCH_UNUSED TRTFallbackTSRegistration =
32
- torch::class_<trtorch::pyapi::TorchFallback>(" tensorrt" , " _TorchFallback" )
33
- .def (torch::init<>())
34
- .def (" __str__" , &trtorch::pyapi::TorchFallback::to_str);
31
+ torch::class_<trtorch::pyapi::TorchFallback>(" tensorrt" , " _TorchFallback" )
32
+ .def (torch::init<>())
33
+ .def (" __str__" , &trtorch::pyapi::TorchFallback::to_str);
35
34
36
35
ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
37
36
ADD_FIELD_GET_SET_REGISTRATION (TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
0 commit comments