@@ -23,25 +23,27 @@
 from torch.fx.passes.operator_support import OperatorSupportBase

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.WARNING)
+logger.setLevel(logging.INFO)


-class OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
+class _OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
     def __init__(
         self,
         skip_ops_for_coreml_delegation: Optional[List[str]] = None,
         lower_full_graph: bool = False,
+        log: bool = False,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
         super().__init__()
         self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
         self.lower_full_graph = lower_full_graph
         self._logged_msgs = set()
+        self._log = log

     def log_once(self, msg: str) -> None:
-        if msg not in self._logged_msgs:
-            logging.info(msg)
+        if self._log and msg not in self._logged_msgs:
+            logger.info(msg)
             self._logged_msgs.add(msg)

     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
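
The effect of the new flag: `log_once` now emits only when the support checker is built with `log=True`, routes messages through the module-level `logger` (raised to INFO above) rather than the root `logging` module, and still deduplicates repeated messages. A minimal standalone sketch of that pattern, using only the standard library (the `_GatedLogger` name is illustrative and not part of this diff):

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class _GatedLogger:
    # Illustrative stand-in for the log_once pattern in the diff; not the real class.
    def __init__(self, log: bool = False) -> None:
        self._log = log
        self._logged_msgs = set()

    def log_once(self, msg: str) -> None:
        # Emit each distinct message at most once, and only when logging is enabled.
        if self._log and msg not in self._logged_msgs:
            logger.info(msg)
            self._logged_msgs.add(msg)


logging.basicConfig(level=logging.INFO)  # attach a handler so INFO records are shown
quiet = _GatedLogger(log=False)
chatty = _GatedLogger(log=True)
quiet.log_once("skipping aten.triu")   # suppressed: log=False
chatty.log_once("skipping aten.triu")  # logged once
chatty.log_once("skipping aten.triu")  # deduplicated, not logged again
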
@@ -154,8 +156,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:

         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
-            OperatorsSupportedForCoreMLBackend(
-                self.skip_ops_for_coreml_delegation, self.lower_full_graph
+            _OperatorsSupportedForCoreMLBackend(
+                self.skip_ops_for_coreml_delegation,
+                self.lower_full_graph,
+                log=True,
             ),
             allows_single_node_partition=True,
         )
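
Because `partition` now passes `log=True` and the module logger sits at INFO, the reason each op is skipped should surface while partitioning, provided the application has a handler at INFO level. A rough usage sketch, assuming the documented `CoreMLPartitioner` and `to_edge_transform_and_lower` entry points (import paths may differ across ExecuTorch versions):

import logging

import torch
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower

logging.basicConfig(level=logging.INFO)  # make the partitioner's INFO messages visible


class Model(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x) + torch.cos(x)


ep = torch.export.export(Model().eval(), (torch.randn(4),))
# Partitioning runs the support check with log=True, so skipped ops are reported once each.
edge = to_edge_transform_and_lower(ep, partitioner=[CoreMLPartitioner()])
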
@@ -191,8 +195,10 @@ def ops_to_not_decompose(
         self, ep: ExportedProgram
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
         do_not_decompose = []
-        op_support = OperatorsSupportedForCoreMLBackend(
-            self.skip_ops_for_coreml_delegation, self.lower_full_graph
+        op_support = _OperatorsSupportedForCoreMLBackend(
+            self.skip_ops_for_coreml_delegation,
+            self.lower_full_graph,
+            log=False,
         )

         # CoreML prevents certain ops (like triu) from lowering to CoreML when put in the ExecuTorch op namespace
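
`ops_to_not_decompose` builds the same support checker but with `log=False`, so running the check a second time (to decide which ops should stay undecomposed) does not repeat the messages already reported during partitioning. If the INFO output is unwanted, a caller can still quiet the module logger from application code; the logger name below is an assumption based on the module path of the CoreML partitioner:

import logging

# Hypothetical logger name, derived from the assumed module path of this file.
logging.getLogger(
    "executorch.backends.apple.coreml.partition.coreml_partitioner"
).setLevel(logging.ERROR)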