99import os
1010import sys
1111from enum import IntEnum
12- from typing import Any , Dict , List , Optional
12+ from typing import Any , Dict , List , Optional , Set
1313
1414import yaml
1515
@@ -85,17 +85,17 @@ class KernelType(IntEnum):
8585
8686
8787def _get_operators (model_file : str ) -> List [str ]:
88+ from executorch .codegen .tools .selective_build import ( # type: ignore[import-not-found]
89+ _get_program_from_buffer ,
90+ _get_program_operators ,
91+ )
92+
8893 print ("Processing model file: " , model_file )
8994 with open (model_file , "rb" ) as f :
9095 buf = f .read ()
9196
92- from executorch .exir ._serialize import _deserialize_pte_binary
93-
94- model = _deserialize_pte_binary (buf )
95- operators = []
96- for execution_plan in model .execution_plan :
97- for op in execution_plan .operators :
98- operators .append (op .name )
97+ program = _get_program_from_buffer (buf )
98+ operators = _get_program_operators (program )
9999 print (f"Model file loaded, operators are: { operators } " )
100100 return operators
101101
@@ -109,47 +109,31 @@ def _get_kernel_metadata_for_model(model_file: str) -> Dict[str, List[str]]:
109109
110110 with open (model_file , "rb" ) as f :
111111 buf = f .read ()
112+
113+ program = _get_program_from_buffer (buf )
114+ operators_with_io_metadata = _get_io_metadata_for_program_operators (program )
115+
112116 op_kernel_key_list : Dict [str , List [str ]] = {}
113117
114- from executorch .exir ._serialize import _deserialize_pte_binary
115- from executorch .exir .schema import (
116- EValue ,
117- KernelCall ,
118- OptionalTensorList ,
119- Tensor ,
120- TensorList ,
121- )
118+ specialized_kernels : Set [List [_IOMetaData ]]
119+ for op_name , specialized_kernels in operators_with_io_metadata .items ():
120+ print (op_name )
121+ if op_name not in op_kernel_key_list :
122+ op_kernel_key_list [op_name ] = []
123+
124+ for specialized_kernel in specialized_kernels :
125+ version = "v1"
126+ kernel_key = version + "/"
127+ for io_metadata in specialized_kernel :
128+ if io_metadata .kernel_type in [
129+ KernelType .TENSOR ,
130+ KernelType .TENSOR_LIST ,
131+ KernelType .OPTIONAL_TENSOR_LIST ,
132+ ]:
133+ dim_order = "," .join (map (str , io_metadata .dim_order ))
134+ kernel_key += f"{ io_metadata .dtype } ;{ dim_order } |"
135+ op_kernel_key_list [op_name ].append (kernel_key [:- 1 ])
122136
123- def _get_dtypes_from_non_list (evalue : EValue ):
124- kernel_key = ""
125- if isinstance (evalue , Tensor ):
126- dim_order = "," .join (map (str , evalue .dim_order ))
127- kernel_key += f"{ evalue .scalar_type } ;{ dim_order } |"
128- return kernel_key
129-
130- model = _deserialize_pte_binary (buf )
131- for execution_plan in model .execution_plan :
132- for chain in execution_plan .chains :
133- for instr in chain .instructions :
134- if not isinstance (instr .instr_args , KernelCall ):
135- continue
136- op_name = execution_plan .operators [instr .instr_args .op_index ].name
137- if op_name not in op_kernel_key_list :
138- op_kernel_key_list [op_name ] = []
139- version = "v1"
140- kernel_key = version + "/"
141- for tensor_arg in instr .instr_args .args :
142- val = execution_plan .values [tensor_arg ].val
143-
144- if isinstance (val , TensorList ) or isinstance (
145- val , OptionalTensorList
146- ):
147- for tensor in val .items :
148- tval = execution_plan .values [tensor ].val
149- kernel_key += _get_dtypes_from_non_list (tval ) # type: ignore[arg-type]
150-
151- kernel_key += _get_dtypes_from_non_list (val ) # type: ignore[arg-type]
152- op_kernel_key_list [op_name ].append (kernel_key [:- 1 ])
153137 return op_kernel_key_list
154138
155139
0 commit comments