26
26
generate_partitions_from_list_of_nodes ,
27
27
)
28
28
29
- from executorch .exir .backend .partitioner import DelegationSpec , Partitioner
29
+ from executorch .exir .backend .partitioner import (
30
+ DelegationSpec ,
31
+ Partitioner ,
32
+ PartitionResult ,
33
+ )
30
34
from executorch .exir .dialects ._ops import ops as exir_ops
31
35
from torch .fx .passes .infra .partitioner import Partition
32
36
from torch .fx .passes .operator_support import OperatorSupportBase
@@ -313,7 +317,6 @@ def __init__(
313
317
self .supported_ops = set (supported_ops or [])
314
318
315
319
self .delegation_spec = DelegationSpec (XnnpackBackend .__name__ , [])
316
- self .partition_tags : Dict [str , DelegationSpec ] = {}
317
320
318
321
@staticmethod
319
322
def check_partitions (partitions : Union [dict , list ]) -> bool :
@@ -386,27 +389,30 @@ def generate_partitions(self, graph_module: torch.fx.GraphModule) -> List[Any]:
386
389
),
387
390
)
388
391
389
- def tag_nodes (self , partitions : List [Partition ]) -> None :
392
+ def tag_nodes (self , partitions : List [Partition ]) -> Dict [ str , DelegationSpec ] :
390
393
"""
391
394
Tag each partition in the list with its delegation tag.
392
395
"""
396
+ partition_tags : Dict [str , DelegationSpec ] = {}
393
397
for partition in partitions :
394
398
# Add delegation tags
395
399
for node in partition .nodes :
396
400
delegation_tag = f"tag{ partition .id } "
397
401
node .meta ["delegation_tag" ] = delegation_tag
398
- self .partition_tags [delegation_tag ] = self .delegation_spec
402
+ partition_tags [delegation_tag ] = self .delegation_spec
403
+ return partition_tags
399
404
400
405
# override
401
- def partition (self , graph_module : torch .fx .GraphModule ) -> torch . fx . GraphModule :
406
+ def partition (self , graph_module : torch .fx .GraphModule ) -> PartitionResult :
402
407
"""
403
408
Run the partitioner on the given graph module, then tag each partition
404
409
with its delegation tag (and partition id)
405
410
"""
406
411
partitions = self .generate_partitions (graph_module )
412
+ partition_tags : Dict [str , DelegationSpec ] = {}
407
413
if self .check_partitions (partitions ):
408
- self .tag_nodes (partitions )
409
- return graph_module
414
+ partition_tags = self .tag_nodes (partitions )
415
+ return PartitionResult ( tagged_graph = graph_module , partition_tags = partition_tags )
410
416
411
417
412
418
# TODO: Merge XnnpackQuantizedPartitioner and XnnpackFloatingPointPartitioner
@@ -761,10 +767,11 @@ def generate_partitions(
761
767
XnnpackOperatorSupport (supported_ops = list (self .get_supported_ops (quant ))),
762
768
)
763
769
764
- def tag_nodes (self , partitions : List [Partition ]) -> None :
770
+ def tag_nodes (self , partitions : List [Partition ]) -> Dict [ str , DelegationSpec ] :
765
771
"""
766
772
Tag each partition in the list with its delegation tag.
767
773
"""
774
+ partition_tags : Dict [str , DelegationSpec ] = {}
768
775
for partition in partitions :
769
776
# Add delegation tags
770
777
skip = False
@@ -776,23 +783,25 @@ def tag_nodes(self, partitions: List[Partition]) -> None:
776
783
for node in partition .nodes :
777
784
delegation_tag = f"tag{ partition .id } "
778
785
node .meta ["delegation_tag" ] = delegation_tag
779
- self .partition_tags [delegation_tag ] = self .delegation_spec
786
+ partition_tags [delegation_tag ] = self .delegation_spec
787
+ return partition_tags
780
788
781
789
# override
782
790
def _partition (
783
791
self , graph_module : torch .fx .GraphModule , quant : Optional [bool ]
784
- ) -> torch . fx . GraphModule :
792
+ ) -> PartitionResult :
785
793
"""
786
794
Run the partitioner on the given graph module, then tag each partition
787
795
with its delegation tag (and partition id)
788
796
"""
789
797
partitions = self .generate_partitions (graph_module , quant )
798
+ partition_tags : Dict [str , DelegationSpec ] = {}
790
799
if self .check_partitions (partitions ):
791
- self .tag_nodes (partitions )
792
- return graph_module
800
+ partition_tags = self .tag_nodes (partitions )
801
+ return PartitionResult ( tagged_graph = graph_module , partition_tags = partition_tags )
793
802
794
- def partition (self , graph_module : torch .fx .GraphModule ) -> torch . fx . GraphModule :
795
- ret = self ._partition (graph_module , self .quant )
803
+ def partition (self , graph_module : torch .fx .GraphModule ) -> PartitionResult :
804
+ ret : PartitionResult = self ._partition (graph_module , self .quant )
796
805
return ret
797
806
798
807
@@ -805,7 +814,7 @@ def __init__(
805
814
super ().__init__ (supported_modules , supported_ops )
806
815
807
816
# override
808
- def partition (self , graph_module : torch .fx .GraphModule ) -> torch . fx . GraphModule :
817
+ def partition (self , graph_module : torch .fx .GraphModule ) -> PartitionResult :
809
818
"""
810
819
Run the partitioner on the given graph module, then tag each partition with its delegegation tag (and partition id)
811
820
@@ -819,7 +828,8 @@ def partition(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
819
828
)
820
829
for match in self .get_module_partitions (graph_module )
821
830
]
831
+ partition_tags : Dict [str , DelegationSpec ] = {}
822
832
823
833
if self .check_partitions (partitions ):
824
- self .tag_nodes (partitions )
825
- return graph_module
834
+ partition_tags = self .tag_nodes (partitions )
835
+ return PartitionResult ( tagged_graph = graph_module , partition_tags = partition_tags )
0 commit comments