|
9 | 9 | import os |
10 | 10 | import sys |
11 | 11 | from enum import IntEnum |
12 | | -from typing import Any, Dict, List, Optional, Set |
| 12 | +from typing import Any, Dict, List, Optional |
13 | 13 |
|
14 | 14 | import yaml |
15 | 15 |
|
|
21 | 21 | from ..parse import strip_et_fields |
22 | 22 |
|
23 | 23 |
|
| 24 | +from executorch.exir._serialize import _deserialize_pte_binary |
| 25 | +from executorch.exir.schema import ( |
| 26 | + EValue, |
| 27 | + KernelCall, |
| 28 | + OptionalTensorList, |
| 29 | + Tensor, |
| 30 | + TensorList, |
| 31 | +) |
24 | 32 | from torchgen.gen import LineLoader, parse_native_yaml_struct |
25 | 33 | from torchgen.selective_build.operator import SelectiveBuildOperator |
26 | 34 | from torchgen.selective_build.selector import merge_et_kernel_metadata |
def _get_operators(model_file: str) -> List[str]:
    """Return the names of all operators referenced by a serialized PTE model.

    Deserializes the flatbuffer program in ``model_file`` and collects
    ``op.name`` from every execution plan's operator table, in order.
    """
    print("Processing model file: ", model_file)
    with open(model_file, "rb") as f:
        buf = f.read()

    model = _deserialize_pte_binary(buf)
    # Flatten every execution plan's operator table into one name list.
    operators = [
        op.name
        for execution_plan in model.execution_plan
        for op in execution_plan.operators
    ]
    print(f"Model file loaded, operators are: {operators}")
    return operators
| 108 | + |
| 109 | + |
def _get_dtypes_from_non_list(evalue: EValue) -> str:
    """Return the kernel-key fragment for a single (non-list) EValue.

    For a ``Tensor`` value the fragment is ``"<scalar_type>;<dim_order>|"``
    (dim order joined with commas, trailing ``|`` as a field separator).
    Any other EValue kind contributes nothing and yields ``""``.
    """
    # Guard clause: only concrete tensors carry dtype/dim-order metadata.
    if not isinstance(evalue, Tensor):
        return ""
    dim_order = ",".join(map(str, evalue.dim_order))
    return f"{evalue.scalar_type};{dim_order}|"
111 | 116 |
|
112 | 117 |
|
def _get_kernel_metadata_for_model(model_file: str) -> Dict[str, List[str]]:
    """Map each operator in a PTE model to its specialized kernel keys.

    Walks every instruction of every chain in every execution plan; for
    each ``KernelCall`` it builds a key of the form
    ``"v1/<dtype>;<dim_order>|..."`` (trailing separator stripped) from the
    tensor arguments, expanding ``TensorList``/``OptionalTensorList``
    arguments into their member tensors.
    """
    with open(model_file, "rb") as f:
        buf = f.read()
    op_kernel_key_list: Dict[str, List[str]] = {}

    model = _deserialize_pte_binary(buf)
    for execution_plan in model.execution_plan:
        for chain in execution_plan.chains:
            for instr in chain.instructions:
                # Only kernel calls reference an operator + tensor args.
                if not isinstance(instr.instr_args, KernelCall):
                    continue
                op_name = execution_plan.operators[instr.instr_args.op_index].name
                keys = op_kernel_key_list.setdefault(op_name, [])
                version = "v1"
                kernel_key = version + "/"
                for tensor_arg in instr.instr_args.args:
                    val = execution_plan.values[tensor_arg].val

                    # List-typed args contribute one fragment per member
                    # tensor; the list value itself contributes nothing
                    # (the call below returns "" for non-Tensor values).
                    if isinstance(val, (TensorList, OptionalTensorList)):
                        for tensor in val.items:
                            tval = execution_plan.values[tensor].val
                            kernel_key += _get_dtypes_from_non_list(tval)  # type: ignore[arg-type]

                    kernel_key += _get_dtypes_from_non_list(val)  # type: ignore[arg-type]
                # NOTE(review): when a call has no tensor args this strips
                # the "/" and records "v1" rather than "v1/" — preserved
                # as-is; confirm downstream consumers expect this.
                keys.append(kernel_key[:-1])
    return op_kernel_key_list
169 | 147 |
|
170 | 148 |
|
171 | 149 | def _get_et_kernel_metadata_from_ops_yaml(ops_yaml_path: str) -> Dict[str, List[str]]: |
|
0 commit comments