@@ -237,14 +237,20 @@ int main(int argc, char** argv) {
       "(Only used when targeting DLA (device-type)) Lets engine run layers on GPU if they are not supported on DLA",
       {"allow-gpu-fallback"});
 
+  args::Flag allow_torch_fallback(
+      parser,
+      "allow-torch-fallback",
+      "Enable layers to run in torch if they are not supported in TensorRT",
+      {"allow-torch-fallback"});
+
   args::Flag disable_tf32(
       parser, "disable-tf32", "Prevent Float32 layers from using the TF32 data format", {"disable-tf32"});
 
   args::ValueFlagList<std::string> enabled_precision(
       parser,
       "precision",
       "(Repeatable) Enabling an operating precision for kernels to use when building the engine (Int8 requires a calibration-cache argument) [ float | float32 | f32 | fp32 | half | float16 | f16 | fp16 | int8 | i8 | char ] (default: float)",
-      {'p', "enabled-precison"});
+      {'p', "enabled-precision"});
   args::ValueFlag<std::string> device_type(
       parser,
       "type",
@@ -267,6 +273,12 @@ int main(int argc, char** argv) {
       "Path to calibration cache file to use for post training quantization",
       {"calibration-cache-file"});
 
+  args::ValueFlagList<std::string> forced_fallback_ops(
+      parser,
+      "forced_fallback_ops",
+      "(Repeatable) List of operators in the graph that should be forced to fallback to Pytorch for execution.",
+      {"ffo", "forced-fallback-ops"});
+
   args::Flag embed_engine(
       parser,
       "embed-engine",
@@ -442,6 +454,10 @@ int main(int argc, char** argv) {
     compile_settings.device.allow_gpu_fallback = true;
   }
 
+  if (allow_torch_fallback) {
+    compile_settings.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
+  }
+
   if (disable_tf32) {
     compile_settings.disable_tf32 = true;
   }
@@ -453,6 +469,18 @@ int main(int argc, char** argv) {
 
     auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file_path);
 
+  if (forced_fallback_ops) {
+    if (!allow_torch_fallback) {
+      trtorch::logging::log(
+          trtorch::logging::Level::kERROR,
+          "Forced fallback ops provided but allow_torch_fallback is False. Please use --allow-torch-fallback to enable automatic fallback of operators.");
+    }
+
+    for (const auto fallback_op : args::get(forced_fallback_ops)) {
+      compile_settings.torch_fallback.forced_fallback_ops.push_back(fallback_op);
+    }
+  }
+
   if (enabled_precision) {
     for (const auto precision : args::get(enabled_precision)) {
       auto dtype = parseDataType(precision);
@@ -563,9 +591,11 @@ int main(int argc, char** argv) {
     return 1;
   }
 
-  if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
-    trtorch::logging::log(trtorch::logging::Level::kERROR, "Module is not currently supported by TRTorch");
-    return 1;
+  if (!allow_torch_fallback) {
+    if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
+      trtorch::logging::log(trtorch::logging::Level::kERROR, "Module is not currently supported by TRTorch");
+      return 1;
+    }
   }
 
   if (save_engine) {
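
For context, the fallback behavior wired up in this diff can also be driven straight from the TRTorch C++ API that trtorchc calls into. The sketch below is a minimal, hypothetical example: only the torch_fallback assignment and the forced_fallback_ops list are taken from the diff itself; the module path, fixed input shape, CompileSpec constructor form, and the trtorch::CompileGraph call are assumptions that may differ between TRTorch versions.

// Standalone sketch (not part of the patch) of the API the new flags feed into.
#include <torch/script.h>
#include "trtorch/trtorch.h"

int main() {
  // Hypothetical TorchScript module and fixed input shape.
  auto mod = torch::jit::load("model.ts");
  trtorch::CompileSpec compile_settings({{1, 3, 224, 224}});

  // Equivalent of --allow-torch-fallback: leave unsupported ops running in Torch.
  compile_settings.torch_fallback = trtorch::CompileSpec::TorchFallback(true);

  // Equivalent of --ffo <op>: force a specific op back to Torch (op name illustrative only).
  compile_settings.torch_fallback.forced_fallback_ops.push_back("aten::max_pool2d");

  // Assumed compile entry point, mirroring how trtorchc builds the module.
  auto trt_mod = trtorch::CompileGraph(mod, compile_settings);
  trt_mod.save("trt_model.ts");
  return 0;
}

On the command line, the same effect would come from passing --allow-torch-fallback together with one or more --ffo <op> (or --forced-fallback-ops <op>) arguments to trtorchc.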