@@ -86,59 +86,88 @@ class KernelType(IntEnum):
8686
8787
8888def _get_operators (model_file : str ) -> List [str ]:
89- from executorch .codegen .tools .selective_build import ( # type: ignore[import-not-found]
90- _get_program_from_buffer ,
91- _get_program_operators ,
92- )
93-
9489 print ("Processing model file: " , model_file )
9590 with open (model_file , "rb" ) as f :
9691 buf = f .read ()
92+ try :
93+ from executorch .codegen .tools .selective_build import ( # type: ignore[import-not-found]
94+ _get_program_from_buffer ,
95+ _get_program_operators ,
96+ )
9797
98- program = _get_program_from_buffer (buf )
99- operators = _get_program_operators (program )
100- print (f"Model file loaded, operators are: { operators } " )
101- return operators
98+ program = _get_program_from_buffer (buf )
99+ operators = _get_program_operators (program )
100+ print (f"Model file loaded, operators are: { operators } " )
101+ return operators
102+ except ModuleNotFoundError :
103+ from executorch .exir ._serialize import _deserialize_pte_binary
104+ model = _deserialize_pte_binary (buf )
105+ operators = set ()
106+ for execPlan in model .execution_plan :
107+ for op in execPlan .operators :
108+ operators .add (op .name )
109+ print (f"Model file loaded, operators are: { operators } " )
110+ return operators
102111
103112
104113def _get_kernel_metadata_for_model (model_file : str ) -> Dict [str , List [str ]]:
105-
106- from executorch .codegen .tools .selective_build import ( # type: ignore[import-not-found]
107- _get_io_metadata_for_program_operators ,
108- _get_program_from_buffer ,
109- _IOMetaData ,
110- )
111-
112114 with open (model_file , "rb" ) as f :
113115 buf = f .read ()
114-
115- program = _get_program_from_buffer (buf )
116- operators_with_io_metadata = _get_io_metadata_for_program_operators (program )
117-
118116 op_kernel_key_list : Dict [str , List [str ]] = {}
119117
120- specialized_kernels : Set [List [_IOMetaData ]]
121- for op_name , specialized_kernels in operators_with_io_metadata .items ():
122- print (op_name )
123- if op_name not in op_kernel_key_list :
124- op_kernel_key_list [op_name ] = []
125-
126- for specialized_kernel in specialized_kernels :
127- version = "v1"
128- kernel_key = version + "/"
129- for io_metadata in specialized_kernel :
130- if io_metadata .kernel_type in [
131- KernelType .TENSOR ,
132- KernelType .TENSOR_LIST ,
133- KernelType .OPTIONAL_TENSOR_LIST ,
134- ]:
135- dim_order = "," .join (map (str , io_metadata .dim_order ))
136- kernel_key += f"{ io_metadata .dtype } ;{ dim_order } |"
137- op_kernel_key_list [op_name ].append (kernel_key [:- 1 ])
118+ try :
119+ from executorch .codegen .tools .selective_build import ( # type: ignore[import-not-found]
120+ _get_io_metadata_for_program_operators ,
121+ _get_program_from_buffer ,
122+ _IOMetaData ,
123+ )
138124
125+ program = _get_program_from_buffer (buf )
126+ operators_with_io_metadata = _get_io_metadata_for_program_operators (program )
127+
128+ specialized_kernels : Set [List [_IOMetaData ]]
129+ for op_name , specialized_kernels in operators_with_io_metadata .items ():
130+ print (op_name )
131+ if op_name not in op_kernel_key_list :
132+ op_kernel_key_list [op_name ] = []
133+
134+ for specialized_kernel in specialized_kernels :
135+ version = "v1"
136+ kernel_key = version + "/"
137+ for io_metadata in specialized_kernel :
138+ if io_metadata .kernel_type in [
139+ KernelType .TENSOR ,
140+ KernelType .TENSOR_LIST ,
141+ KernelType .OPTIONAL_TENSOR_LIST ,
142+ ]:
143+ dim_order = "," .join (map (str , io_metadata .dim_order ))
144+ kernel_key += f"{ io_metadata .dtype } ;{ dim_order } |"
145+ op_kernel_key_list [op_name ].append (kernel_key [:- 1 ])
146+
147+ except ModuleNotFoundError :
148+ from executorch .exir ._serialize import _deserialize_pte_binary
149+ model = _deserialize_pte_binary (buf )
150+ for execPlan in model .execution_plan :
151+ for chain in execPlan .chains :
152+ for instr in chain .instructions :
153+ op_name = execPlan .operators [instr .instr_args .op_index ].name
154+ if op_name not in op_kernel_key_list :
155+ op_kernel_key_list [op_name ] = []
156+ version = "v1"
157+ kernel_key = version + "/"
158+ # TODO what happens when tensors have different types withina single kernel/ is that even allowed?
159+ for tensor_arg in instr .instr_args .args :
160+ val = execPlan .values [tensor_arg ].val
161+
162+ # TODO is there a better way to do this?
163+ if type (val ).__name__ == "Tensor" :
164+ dim_order = "," .join (map (str ,val .dim_order ))
165+ kernel_key += f"{ val .scalar_type } ;{ dim_order } |"
166+ op_kernel_key_list [op_name ].append (kernel_key [:- 1 ])
139167 return op_kernel_key_list
140168
141169
170+
142171def _get_et_kernel_metadata_from_ops_yaml (ops_yaml_path : str ) -> Dict [str , List [str ]]:
143172 ops = []
144173 with open (ops_yaml_path , "r" ) as f :
0 commit comments