Skip to content

Commit 9159927

Browse files
committed
Undoing changes to gen_oplist
1 parent 72db779 commit 9159927

File tree

1 file changed

+30
-46
lines changed

1 file changed

+30
-46
lines changed

codegen/tools/gen_oplist.py

Lines changed: 30 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import os
1010
import sys
1111
from enum import IntEnum
12-
from typing import Any, Dict, List, Optional
12+
from typing import Any, Dict, List, Optional, Set
1313

1414
import yaml
1515

@@ -85,17 +85,17 @@ class KernelType(IntEnum):
8585

8686

8787
def _get_operators(model_file: str) -> List[str]:
88+
from executorch.codegen.tools.selective_build import ( # type: ignore[import-not-found]
89+
_get_program_from_buffer,
90+
_get_program_operators,
91+
)
92+
8893
print("Processing model file: ", model_file)
8994
with open(model_file, "rb") as f:
9095
buf = f.read()
9196

92-
from executorch.exir._serialize import _deserialize_pte_binary
93-
94-
model = _deserialize_pte_binary(buf)
95-
operators = []
96-
for execution_plan in model.execution_plan:
97-
for op in execution_plan.operators:
98-
operators.append(op.name)
97+
program = _get_program_from_buffer(buf)
98+
operators = _get_program_operators(program)
9999
print(f"Model file loaded, operators are: {operators}")
100100
return operators
101101

@@ -109,47 +109,31 @@ def _get_kernel_metadata_for_model(model_file: str) -> Dict[str, List[str]]:
109109

110110
with open(model_file, "rb") as f:
111111
buf = f.read()
112+
113+
program = _get_program_from_buffer(buf)
114+
operators_with_io_metadata = _get_io_metadata_for_program_operators(program)
115+
112116
op_kernel_key_list: Dict[str, List[str]] = {}
113117

114-
from executorch.exir._serialize import _deserialize_pte_binary
115-
from executorch.exir.schema import (
116-
EValue,
117-
KernelCall,
118-
OptionalTensorList,
119-
Tensor,
120-
TensorList,
121-
)
118+
specialized_kernels: Set[List[_IOMetaData]]
119+
for op_name, specialized_kernels in operators_with_io_metadata.items():
120+
print(op_name)
121+
if op_name not in op_kernel_key_list:
122+
op_kernel_key_list[op_name] = []
123+
124+
for specialized_kernel in specialized_kernels:
125+
version = "v1"
126+
kernel_key = version + "/"
127+
for io_metadata in specialized_kernel:
128+
if io_metadata.kernel_type in [
129+
KernelType.TENSOR,
130+
KernelType.TENSOR_LIST,
131+
KernelType.OPTIONAL_TENSOR_LIST,
132+
]:
133+
dim_order = ",".join(map(str, io_metadata.dim_order))
134+
kernel_key += f"{io_metadata.dtype};{dim_order}|"
135+
op_kernel_key_list[op_name].append(kernel_key[:-1])
122136

123-
def _get_dtypes_from_non_list(evalue: EValue):
124-
kernel_key = ""
125-
if isinstance(evalue, Tensor):
126-
dim_order = ",".join(map(str, evalue.dim_order))
127-
kernel_key += f"{evalue.scalar_type};{dim_order}|"
128-
return kernel_key
129-
130-
model = _deserialize_pte_binary(buf)
131-
for execution_plan in model.execution_plan:
132-
for chain in execution_plan.chains:
133-
for instr in chain.instructions:
134-
if not isinstance(instr.instr_args, KernelCall):
135-
continue
136-
op_name = execution_plan.operators[instr.instr_args.op_index].name
137-
if op_name not in op_kernel_key_list:
138-
op_kernel_key_list[op_name] = []
139-
version = "v1"
140-
kernel_key = version + "/"
141-
for tensor_arg in instr.instr_args.args:
142-
val = execution_plan.values[tensor_arg].val
143-
144-
if isinstance(val, TensorList) or isinstance(
145-
val, OptionalTensorList
146-
):
147-
for tensor in val.items:
148-
tval = execution_plan.values[tensor].val
149-
kernel_key += _get_dtypes_from_non_list(tval) # type: ignore[arg-type]
150-
151-
kernel_key += _get_dtypes_from_non_list(val) # type: ignore[arg-type]
152-
op_kernel_key_list[op_name].append(kernel_key[:-1])
153137
return op_kernel_key_list
154138

155139

0 commit comments

Comments
 (0)