@@ -111,6 +111,7 @@ def __init__(
         min_block_size: int = MIN_BLOCK_SIZE,
         require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
         return_tuple: bool = False,
+        skip_fusion: bool = False,
     ):
         """
         Preprocesses graph before splitting:
@@ -127,6 +128,7 @@ def __init__(
         self.settings = _SplitterSettingBase(
             min_acc_module_size=min_block_size,
             allow_non_tensor=True,
+            skip_fusion=skip_fusion,
         )
         self.operator_support = operator_support

@@ -252,6 +254,7 @@ def partition(
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Collection[Target] = set(),
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
+    skip_fusion: bool = False,
 ) -> Tuple[torch.fx.GraphModule, OpSupportTester]:
     """Partition an FX GraphModule with aten ops into TRT engines
     Partitioning is based on converter operator support
@@ -262,6 +265,7 @@ def partition(
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
         require_full_compilation: Require that all computational operators be run in TRT
+        skip_fusion: Skip fusions found by FxNetAccFusionsFinder
     Returns:
         torch.fx.GraphModule, OpSupportTester
     """
@@ -277,6 +281,7 @@ def partition(
         supported_ops,
         min_block_size=min_block_size,
         require_full_compilation=require_full_compilation,
+        skip_fusion=skip_fusion,
     )

     partitioned_graph = partitioner.partition_graph()
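
For reference, a minimal usage sketch of the new flag, assuming the partition() helper shown above is importable from Torch-TensorRT's Dynamo partitioning package and takes the FX GraphModule as its first positional argument; the import path and example model below are illustrative, not part of this change:

    import torch
    # Assumed import path; adjust to wherever this partition() helper lives.
    from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import partition


    class Sample(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(x) + 1.0


    # Export the module to an FX GraphModule containing aten ops.
    gm = torch.export.export(Sample(), (torch.randn(4),)).module()

    # skip_fusion=True is forwarded to _SplitterSettingBase, so the splitter
    # skips the fusions found by FxNetAccFusionsFinder when forming blocks.
    partitioned_gm, op_support = partition(gm, min_block_size=1, skip_fusion=True)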