Skip to content

Commit 77500cd

Browse files
zonglinpengfacebook-github-bot
authored andcommitted
retrieve cadence_passes in apply_jarvis_passes
Summary: Merge cadence passes in apply_jarvis_passes. This method partition OS passes an jarvis passes on the pass filter: - OSS passes remain the same order as original jarvis passes `remove -> fusion -> replacement ` - **TODO1**: Jarvis only has `ReplaceMatmulWithTransposedMatmulPass` and `ReplaceAvgPoolWithChannelLastAvgPoolPass` left. Matmul should also be used by vision backend should is OSS-able. TBC by mcremon-meta - Jarvis should keep `precompute_for_quantized_linear_pass` simplify because it's only used by v3. skrtskrtfb will make it as a ExportPass. - Jarvis passes are appended to OSS passes - OSS flow: OSS pass filter applies on opt-level and is-debug. Apply OSS pass filter on OSS passes, which guarentee correct results - Jarvis flow: Jarvis pass filter applies on opt-level and is-debug backend. The same filter cannot be reused from OSS because passes are registered in separate pass dictionaries. To get the correct set of passes after filtetring, 1. merge OSS & jarvis pass filter, 2. combine OSS & jarvis passes, 3. apply the merged filter on all passes. **Note**: The filter should return true if __OSS pass attribute **OR** Jarvis pass attribute matches with the filtering criteria__. - Test filter adds coverage in corner case when some passes are filtered by opt-level Differential Revision: D72999266
1 parent 2ff9abd commit 77500cd

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

backends/cadence/aot/pass_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ class CadencePassAttribute:
3535
ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}
3636

3737

38-
def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute:
39-
return ALL_CADENCE_PASSES[p]
38+
def get_cadence_pass_attribute(p: ExportPass) -> Optional[CadencePassAttribute]:
39+
return ALL_CADENCE_PASSES.get(p, None)
4040

4141

4242
# A decorator that registers a pass.
@@ -61,7 +61,8 @@ def create_cadence_pass_filter(
6161
def _filter(p: ExportPass) -> bool:
6262
pass_attribute = get_cadence_pass_attribute(p)
6363
return (
64-
pass_attribute.opt_level is not None
64+
pass_attribute is not None
65+
and pass_attribute.opt_level is not None
6566
and pass_attribute.opt_level <= opt_level
6667
and (not pass_attribute.debug_pass or debug)
6768
)

0 commit comments

Comments
 (0)