77# pyre-strict
88
99from dataclasses import dataclass
10- from typing import Callable , List , Optional , Set , Type , Union
10+ from typing import Callable , List , Optional , Set , Union
1111
1212import torch
1313from executorch .backends .cadence .aot .utils import get_edge_overload_packet
@@ -32,33 +32,33 @@ class CadencePassAttribute:
3232
3333
3434# A dictionary that maps an ExportPass to its attributes.
35- ALL_CADENCE_PASSES : dict [Type [ ExportPass ] , CadencePassAttribute ] = {}
35+ ALL_CADENCE_PASSES : dict [ExportPass , CadencePassAttribute ] = {}
3636
3737
38- def get_cadence_pass_attribute (p : Type [ ExportPass ] ) -> CadencePassAttribute :
38+ def get_cadence_pass_attribute (p : ExportPass ) -> CadencePassAttribute :
3939 return ALL_CADENCE_PASSES [p ]
4040
4141
4242# A decorator that registers a pass.
4343def register_cadence_pass (
4444 pass_attribute : CadencePassAttribute ,
45- ) -> Callable [[Type [ ExportPass ]], Type [ ExportPass ] ]:
46- def wrapper (cls : Type [ ExportPass ] ) -> Type [ ExportPass ] :
45+ ) -> Callable [[ExportPass ], ExportPass ]:
46+ def wrapper (cls : ExportPass ) -> ExportPass :
4747 ALL_CADENCE_PASSES [cls ] = pass_attribute
4848 return cls
4949
5050 return wrapper
5151
5252
53- def get_all_available_cadence_passes () -> Set [Type [ ExportPass ] ]:
53+ def get_all_available_cadence_passes () -> Set [ExportPass ]:
5454 return set (ALL_CADENCE_PASSES .keys ())
5555
5656
5757# Create a new filter to filter out relevant passes from all passes.
5858def create_cadence_pass_filter (
5959 opt_level : int , debug : bool = False
60- ) -> Callable [[Type [ ExportPass ] ], bool ]:
61- def _filter (p : Type [ ExportPass ] ) -> bool :
60+ ) -> Callable [[ExportPass ], bool ]:
61+ def _filter (p : ExportPass ) -> bool :
6262 pass_attribute = get_cadence_pass_attribute (p )
6363 return (
6464 pass_attribute .opt_level is not None
0 commit comments