
Commit 9bf2456

docs(//py): Adding documentation on the limitation of the backend API
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent bf1b2d8 commit 9bf2456

4 files changed: 20 additions & 16 deletions

py/trtorch/_compile_spec.py

Lines changed: 7 additions & 1 deletion
@@ -214,6 +214,10 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
         One key is required which is ``input_shapes``, describing the input sizes or ranges for inputs
         to the graph. All other keys are optional. Entries for each method to be compiled.
 
+        Note: Partial compilation of TorchScript modules is not supported through the PyTorch TensorRT backend
+        If you need this feature, use trtorch.compile to compile your module. Usage of the resulting module is
+        as if you were using the TensorRT integration.
+
         .. code-block:: py
 
             CompileSpec = {
@@ -272,7 +276,9 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
     d._set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)
 
     if parsed_spec.torch_fallback.enabled:
-        raise RuntimeError("Partial module compilation is not currently supported via the PyTorch to_backend API integration. If you need partial compilation, use trtorch.compile")
+        raise RuntimeError(
+            "Partial module compilation is not currently supported via the PyTorch TensorRT backend. If you need partial compilation, use trtorch.compile"
+        )
 
     torch_fallback = torch.classes.tensorrt._TorchFallback()
     torch_fallback._set_enabled(parsed_spec.torch_fallback.enabled)
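
For illustration, a minimal sketch of what this change means for callers (the input shapes are placeholders, and it assumes TensorRTCompileSpec is exposed at the top-level trtorch module): requesting torch_fallback through the backend compile spec now fails with the reworded error, which points users at trtorch.compile instead.

    import trtorch

    # Placeholder spec; only the keys relevant to the limitation are shown.
    spec = {
        "input_shapes": [[1, 3, 300, 300]],
        "torch_fallback": {"enabled": True},  # partial compilation requested
    }

    try:
        # Per the diff above, this raises:
        #   RuntimeError: Partial module compilation is not currently supported via
        #   the PyTorch TensorRT backend. If you need partial compilation, use trtorch.compile
        backend_spec = trtorch.TensorRTCompileSpec(spec)
    except RuntimeError as err:
        print(err)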

py/trtorch/_compiler.py

Lines changed: 3 additions & 3 deletions
@@ -50,11 +50,11 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
             "workspace_size": 0, # Maximum size of workspace given to TensorRT
             "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
             "torch_fallback": {
-                "enabled": True,
+                "enabled": True, # Turn on or turn off falling back to PyTorch if operations are not supported in TensorRT
                 "force_fallback_ops": [
-                    "aten::max_pool2d"
+                    "aten::max_pool2d" # List of specific ops to require running in PyTorch
                 ],
-                "min_block_size": 1
+                "min_block_size": 3 # Minimum number of ops an engine must incapsulate to be run in TensorRT
             }
         }
 
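
The supported route, per the updated docstring above, is trtorch.compile itself. A hedged sketch, assuming a hypothetical scripted_model (any torch.jit.ScriptModule) and placeholder input shapes; the torch_fallback keys are the ones documented in the diff:

    import torch
    import trtorch

    # scripted_model is assumed to exist, e.g.:
    #   scripted_model = torch.jit.script(my_model.eval())
    compile_spec = {
        "input_shapes": [[1, 3, 224, 224]],  # placeholder static shape
        "torch_fallback": {
            "enabled": True,                             # run unsupported ops in PyTorch
            "force_fallback_ops": ["aten::max_pool2d"],  # ops forced to stay in PyTorch
            "min_block_size": 3,                         # min contiguous supported ops per TensorRT engine
        },
    }

    trt_module = trtorch.compile(scripted_model, compile_spec)
    output = trt_module(torch.randn(1, 3, 224, 224).to("cuda"))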

py/trtorch/csrc/register_tensorrt_classes.cpp

Lines changed: 10 additions & 11 deletions
@@ -5,33 +5,32 @@ namespace backend {
 namespace {
 
 #define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \
-  (registry).def("_set_" #field_name, &class_name::set_##field_name); \
+  (registry).def("_set_" #field_name, &class_name::set_##field_name);   \
   (registry).def("_get_" #field_name, &class_name::get_##field_name);
 
 void RegisterTRTCompileSpec() {
   static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
-    torch::class_<trtorch::pyapi::InputRange>("tensorrt", "_InputRange")
-      .def(torch::init<>())
-      .def("__str__", &trtorch::pyapi::InputRange::to_str);
+      torch::class_<trtorch::pyapi::InputRange>("tensorrt", "_InputRange")
+          .def(torch::init<>())
+          .def("__str__", &trtorch::pyapi::InputRange::to_str);
 
   ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min);
   ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt);
   ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max);
 
-  static auto TRTORCH_UNUSED TRTDeviceTSRegistration =
-    torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
-      .def(torch::init<>())
-      .def("__str__", &trtorch::pyapi::Device::to_str);
+  static auto TRTORCH_UNUSED TRTDeviceTSRegistration = torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
+                                                           .def(torch::init<>())
+                                                           .def("__str__", &trtorch::pyapi::Device::to_str);
 
   ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type);
   ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id);
   ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, dla_core);
   ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);
 
   static auto TRTORCH_UNUSED TRTFallbackTSRegistration =
-    torch::class_<trtorch::pyapi::TorchFallback>("tensorrt", "_TorchFallback")
-      .def(torch::init<>())
-      .def("__str__", &trtorch::pyapi::TorchFallback::to_str);
+      torch::class_<trtorch::pyapi::TorchFallback>("tensorrt", "_TorchFallback")
+          .def(torch::init<>())
+          .def("__str__", &trtorch::pyapi::TorchFallback::to_str);
 
   ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
   ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
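
For context, the macro above is what gives each registered TorchScript class its _set_<field>/_get_<field> accessors, which _compile_spec.py drives from Python. A sketch, assuming the trtorch extension has been imported so RegisterTRTCompileSpec() has run; the accessor names follow directly from the registrations shown:

    import torch
    import trtorch  # loading the extension registers the tensorrt.* custom classes

    fallback = torch.classes.tensorrt._TorchFallback()
    fallback._set_enabled(False)      # from ADD_FIELD_GET_SET_REGISTRATION(..., enabled)
    fallback._set_min_block_size(3)   # from ADD_FIELD_GET_SET_REGISTRATION(..., min_block_size)
    print(fallback._get_enabled(), fallback._get_min_block_size())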

py/trtorch/csrc/tensorrt_classes.h

Lines changed: 0 additions & 1 deletion
@@ -93,7 +93,6 @@ struct TorchFallback : torch::CustomClassHolder {
   std::string to_str();
 };
 
-
 enum class EngineCapability : int8_t {
   kDEFAULT,
   kSAFE_GPU,
