
Commit 7861ab3

Merge pull request #540 from NVIDIA/fallback_trtorchc
feat: Support fallback options in trtorchc
2 parents: 7cdd9f5 + 01ffb5a

3 files changed: +54, -8 lines

cpp/trtorchc/README.md

Lines changed: 11 additions & 3 deletions
@@ -36,15 +36,22 @@ trtorchc [input_file_path] [output_file_path]
     --allow-gpu-fallback              (Only used when targeting DLA
                                       (device-type)) Lets engine run layers on
                                       GPU if they are not supported on DLA
+
+    --allow-torch-fallback            Enable layers to run in torch
+                                      if they are not supported in TensorRT
+
     --disable-tf32                    Prevent Float32 layers from using the
                                       TF32 data format
     -p[precision...],
     --enabled-precison=[precision...] (Repeatable) Enabling an operating
                                       precision for kernels to use when
-                                      building the engine (Int8 requires a
-                                      calibration-cache argument) [ float |
+                                      building the engine [ float |
                                       float32 | f32 | half | float16 | f16 |
                                       int8 | i8 ] (default: float)
+
+    --ffo,
+    --forced-fallback-ops             List of operators in the graph that
+                                      should be forced to fallback to Pytorch for execution
     -d[type], --device-type=[type]    The type of device the engine should be
                                       built for [ gpu | dla ] (default: gpu)
     --gpu-id=[gpu_id]                 GPU id if running on multi-GPU platform
@@ -96,6 +103,7 @@ trtorchc [input_file_path] [output_file_path]
 ```
 
 e.g.
+
 ```
 trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
-```
+```
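
For reference, a command exercising both new flags, in the style of the example above (the module path and operator name here are illustrative, not taken from this commit):

```
trtorchc tests/modules/module.jit.pt module_trt.ts "(1,3,300,300)" --allow-torch-fallback --ffo aten::max_pool2d -p f16
```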

cpp/trtorchc/main.cpp

Lines changed: 34 additions & 4 deletions
@@ -237,14 +237,20 @@ int main(int argc, char** argv) {
       "(Only used when targeting DLA (device-type)) Lets engine run layers on GPU if they are not supported on DLA",
       {"allow-gpu-fallback"});
 
+  args::Flag allow_torch_fallback(
+      parser,
+      "allow-torch-fallback",
+      "Enable layers to run in torch if they are not supported in TensorRT",
+      {"allow-torch-fallback"});
+
   args::Flag disable_tf32(
       parser, "disable-tf32", "Prevent Float32 layers from using the TF32 data format", {"disable-tf32"});
 
   args::ValueFlagList<std::string> enabled_precision(
       parser,
       "precision",
       "(Repeatable) Enabling an operating precision for kernels to use when building the engine (Int8 requires a calibration-cache argument) [ float | float32 | f32 | fp32 | half | float16 | f16 | fp16 | int8 | i8 | char ] (default: float)",
-      {'p', "enabled-precison"});
+      {'p', "enabled-precision"});
   args::ValueFlag<std::string> device_type(
       parser,
       "type",
@@ -267,6 +273,12 @@ int main(int argc, char** argv) {
       "Path to calibration cache file to use for post training quantization",
       {"calibration-cache-file"});
 
+  args::ValueFlagList<std::string> forced_fallback_ops(
+      parser,
+      "forced_fallback_ops",
+      "(Repeatable) List of operators in the graph that should be forced to fallback to Pytorch for execution.",
+      {"ffo", "forced-fallback-ops"});
+
   args::Flag embed_engine(
       parser,
       "embed-engine",
@@ -442,6 +454,10 @@ int main(int argc, char** argv) {
     compile_settings.device.allow_gpu_fallback = true;
   }
 
+  if (allow_torch_fallback) {
+    compile_settings.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
+  }
+
   if (disable_tf32) {
     compile_settings.disable_tf32 = true;
   }
@@ -453,6 +469,18 @@ int main(int argc, char** argv) {
 
   auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file_path);
 
+  if (forced_fallback_ops) {
+    if (!allow_torch_fallback) {
+      trtorch::logging::log(
+          trtorch::logging::Level::kERROR,
+          "Forced fallback ops provided but allow_torch_fallback is False. Please use --allow-torch-fallback to enable automatic fallback of operators.");
+    }
+
+    for (const auto fallback_op : args::get(forced_fallback_ops)) {
+      compile_settings.torch_fallback.forced_fallback_ops.push_back(fallback_op);
+    }
+  }
+
   if (enabled_precision) {
     for (const auto precision : args::get(enabled_precision)) {
       auto dtype = parseDataType(precision);
@@ -563,9 +591,11 @@ int main(int argc, char** argv) {
     return 1;
   }
 
-  if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
-    trtorch::logging::log(trtorch::logging::Level::kERROR, "Module is not currently supported by TRTorch");
-    return 1;
+  if (!allow_torch_fallback) {
+    if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
+      trtorch::logging::log(trtorch::logging::Level::kERROR, "Module is not currently supported by TRTorch");
+      return 1;
+    }
   }
 
   if (save_engine) {
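
trtorchc drives the TRTorch C++ API, so the new flags map directly onto the `torch_fallback` fields set above. A minimal sketch of the programmatic equivalent, assuming the fixed-shape `trtorch::CompileSpec` constructor and the `trtorch::CompileGraph` entry point from that API (module path and operator name are illustrative):

```
#include <torch/script.h>
#include "trtorch/trtorch.h"

int main() {
  // Load a TorchScript module (path is illustrative).
  auto mod = torch::jit::load("tests/modules/module.jit.pt");

  // One fixed input shape, as trtorchc would build from its positional spec.
  trtorch::CompileSpec compile_settings({{1, 3, 300, 300}});

  // Equivalent of --allow-torch-fallback: let unsupported layers run in torch.
  compile_settings.torch_fallback = trtorch::CompileSpec::TorchFallback(true);

  // Equivalent of --ffo/--forced-fallback-ops (operator name is illustrative).
  compile_settings.torch_fallback.forced_fallback_ops.push_back("aten::max_pool2d");

  // Compile: supported subgraphs go to TensorRT, the rest stays in torch.
  auto trt_mod = trtorch::CompileGraph(mod, compile_settings);
  trt_mod.save("module_trt.ts");
  return 0;
}
```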

docsrc/tutorials/trtorchc.rst

Lines changed: 9 additions & 1 deletion
@@ -39,7 +39,14 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r
     --allow-gpu-fallback              (Only used when targeting DLA
                                       (device-type)) Lets engine run layers on
                                       GPU if they are not supported on DLA
-    --disable-tf32                    Prevent Float32 layers from using the
+
+    --allow-torch-fallback            Enable layers to run in torch
+                                      if they are not supported in TensorRT
+    --ffo,
+    --forced-fallback-ops             List of operators in the graph that
+                                      should be forced to fallback to Pytorch for execution
+
+    --disable-tf32                    Prevent Float32 layers from using the
                                       TF32 data format
     -p[precision...],
     --enabled-precison=[precision...] (Repeatable) Enabling an operating
@@ -48,6 +55,7 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r
                                       calibration-cache argument) [ float |
                                       float32 | f32 | half | float16 | f16 |
                                       int8 | i8 ] (default: float)
+
     -d[type], --device-type=[type]    The type of device the engine should be
                                       built for [ gpu | dla ] (default: gpu)
     --gpu-id=[gpu_id]                 GPU id if running on multi-GPU platform
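
As the surrounding tutorial text notes, the output is standard TorchScript even with fallback enabled, so it loads and runs like any other module: ``torch.jit.load()`` in Python, or ``torch::jit::load`` in C++. A small sketch of the latter (path and shape are illustrative):

```
#include <torch/script.h>

int main() {
  // A fallback-enabled module is plain TorchScript; forced-fallback ops
  // simply execute in PyTorch at runtime.
  auto mod = torch::jit::load("module_trt.ts");

  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::randn({1, 3, 300, 300}, torch::kCUDA));

  auto out = mod.forward(inputs);
  return 0;
}
```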
