File tree Expand file tree Collapse file tree 5 files changed +9
-5
lines changed Expand file tree Collapse file tree 5 files changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -19,7 +19,7 @@ struct CompileSpec {
1919 partitioning::PartitionInfo partition_info;
2020};
2121
22- bool CheckMethodOperatorSupport (const torch::jit::script::Module& mod, std::string method_name);
22+ bool CheckMethodOperatorSupport (const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg );
2323
2424std::string ConvertGraphToTRTEngine (const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg);
2525
Original file line number Diff line number Diff line change @@ -17,7 +17,7 @@ struct LowerInfo {
1717 bool disable_cse = false ;
1818 std::vector<std::string> forced_fallback_modules;
1919 friend std::ostream& operator <<(std::ostream& os, const LowerInfo& l);
20- }
20+ };
2121
2222void LowerBlock (torch::jit::Block* b);
2323void LowerGraph (std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
Original file line number Diff line number Diff line change @@ -577,6 +577,9 @@ struct TRTORCH_API CompileSpec {
577577 // / A list of names of operations that will explicitly run in PyTorch
578578 std::vector<std::string> forced_fallback_ops;
579579
580+ // / A list of names of modules that will explicitly run in PyTorch
581+ std::vector<std::string> forced_fallback_modules;
582+
580583 /* *
581584 * @brief Construct a default Torch Fallback object, fallback will be off
582585 */
@@ -781,7 +784,7 @@ TRTORCH_API void dump_build_info();
781784 *
782785 * @returns bool: Method is supported by TRTorch
783786 */
784- TRTORCH_API bool CheckMethodOperatorSupport (const torch::jit::Module& module , std::string method_name);
787+ TRTORCH_API bool CheckMethodOperatorSupport (const torch::jit::Module& module , std::string method_name, CompileSpec info );
785788
786789/* *
787790 * @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
Original file line number Diff line number Diff line change @@ -375,6 +375,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
375375 internal.partition_info .enabled = external.torch_fallback .enabled ;
376376 internal.partition_info .min_block_size = external.torch_fallback .min_block_size ;
377377 internal.partition_info .forced_fallback_operators = external.torch_fallback .forced_fallback_ops ;
378+ internal.lower_info .forced_fallback_modules = external.torch_fallback .forced_fallback_modules ;
378379
379380 switch (external.device .device_type ) {
380381 case CompileSpec::Device::DeviceType::kDLA :
Original file line number Diff line number Diff line change @@ -11,8 +11,8 @@ namespace trtorch {
1111core::CompileSpec to_internal_compile_spec (CompileSpec external);
1212core::runtime::CudaDevice to_internal_cuda_device (CompileSpec::Device device);
1313
14- bool CheckMethodOperatorSupport (const torch::jit::script::Module& module , std::string method_name) {
15- return core::CheckMethodOperatorSupport (module , method_name);
14+ bool CheckMethodOperatorSupport (const torch::jit::script::Module& module , std::string method_name, CompileSpec info ) {
15+ return core::CheckMethodOperatorSupport (module , method_name, to_internal_compile_spec (info) );
1616}
1717
1818std::string ConvertGraphToTRTEngine (
You can’t perform that action at this time.
0 commit comments