Commit a20fc5f

Allows yaml ET file to enable dtype selective build
1 parent 68c58e5 commit a20fc5f

File tree

1 file changed

codegen/tools/gen_oplist.py

Lines changed: 67 additions & 38 deletions
@@ -86,59 +86,88 @@ class KernelType(IntEnum):
 
 
 def _get_operators(model_file: str) -> List[str]:
-    from executorch.codegen.tools.selective_build import (  # type: ignore[import-not-found]
-        _get_program_from_buffer,
-        _get_program_operators,
-    )
-
     print("Processing model file: ", model_file)
     with open(model_file, "rb") as f:
         buf = f.read()
+    try:
+        from executorch.codegen.tools.selective_build import (  # type: ignore[import-not-found]
+            _get_program_from_buffer,
+            _get_program_operators,
+        )
 
-    program = _get_program_from_buffer(buf)
-    operators = _get_program_operators(program)
-    print(f"Model file loaded, operators are: {operators}")
-    return operators
+        program = _get_program_from_buffer(buf)
+        operators = _get_program_operators(program)
+        print(f"Model file loaded, operators are: {operators}")
+        return operators
+    except ModuleNotFoundError:
+        from executorch.exir._serialize import _deserialize_pte_binary
+        model = _deserialize_pte_binary(buf)
+        operators = set()
+        for execPlan in model.execution_plan:
+            for op in execPlan.operators:
+                operators.add(op.name)
+        print(f"Model file loaded, operators are: {operators}")
+        return operators
 
 
 def _get_kernel_metadata_for_model(model_file: str) -> Dict[str, List[str]]:
-
-    from executorch.codegen.tools.selective_build import (  # type: ignore[import-not-found]
-        _get_io_metadata_for_program_operators,
-        _get_program_from_buffer,
-        _IOMetaData,
-    )
-
     with open(model_file, "rb") as f:
         buf = f.read()
-
-    program = _get_program_from_buffer(buf)
-    operators_with_io_metadata = _get_io_metadata_for_program_operators(program)
-
     op_kernel_key_list: Dict[str, List[str]] = {}
 
-    specialized_kernels: Set[List[_IOMetaData]]
-    for op_name, specialized_kernels in operators_with_io_metadata.items():
-        print(op_name)
-        if op_name not in op_kernel_key_list:
-            op_kernel_key_list[op_name] = []
-
-        for specialized_kernel in specialized_kernels:
-            version = "v1"
-            kernel_key = version + "/"
-            for io_metadata in specialized_kernel:
-                if io_metadata.kernel_type in [
-                    KernelType.TENSOR,
-                    KernelType.TENSOR_LIST,
-                    KernelType.OPTIONAL_TENSOR_LIST,
-                ]:
-                    dim_order = ",".join(map(str, io_metadata.dim_order))
-                    kernel_key += f"{io_metadata.dtype};{dim_order}|"
-            op_kernel_key_list[op_name].append(kernel_key[:-1])
+    try:
+        from executorch.codegen.tools.selective_build import (  # type: ignore[import-not-found]
+            _get_io_metadata_for_program_operators,
+            _get_program_from_buffer,
+            _IOMetaData,
+        )
 
+        program = _get_program_from_buffer(buf)
+        operators_with_io_metadata = _get_io_metadata_for_program_operators(program)
+
+        specialized_kernels: Set[List[_IOMetaData]]
+        for op_name, specialized_kernels in operators_with_io_metadata.items():
+            print(op_name)
+            if op_name not in op_kernel_key_list:
+                op_kernel_key_list[op_name] = []
+
+            for specialized_kernel in specialized_kernels:
+                version = "v1"
+                kernel_key = version + "/"
+                for io_metadata in specialized_kernel:
+                    if io_metadata.kernel_type in [
+                        KernelType.TENSOR,
+                        KernelType.TENSOR_LIST,
+                        KernelType.OPTIONAL_TENSOR_LIST,
+                    ]:
+                        dim_order = ",".join(map(str, io_metadata.dim_order))
+                        kernel_key += f"{io_metadata.dtype};{dim_order}|"
+                op_kernel_key_list[op_name].append(kernel_key[:-1])
+
+    except ModuleNotFoundError:
+        from executorch.exir._serialize import _deserialize_pte_binary
+        model = _deserialize_pte_binary(buf)
+        for execPlan in model.execution_plan:
+            for chain in execPlan.chains:
+                for instr in chain.instructions:
+                    op_name = execPlan.operators[instr.instr_args.op_index].name
+                    if op_name not in op_kernel_key_list:
+                        op_kernel_key_list[op_name] = []
+                    version = "v1"
+                    kernel_key = version + "/"
+                    # TODO what happens when tensors have different types within a single kernel? Is that even allowed?
+                    for tensor_arg in instr.instr_args.args:
+                        val = execPlan.values[tensor_arg].val
+
+                        # TODO is there a better way to do this?
+                        if type(val).__name__ == "Tensor":
+                            dim_order = ",".join(map(str, val.dim_order))
+                            kernel_key += f"{val.scalar_type};{dim_order}|"
+                    op_kernel_key_list[op_name].append(kernel_key[:-1])
     return op_kernel_key_list
 
 
+
 def _get_et_kernel_metadata_from_ops_yaml(ops_yaml_path: str) -> Dict[str, List[str]]:
     ops = []
     with open(ops_yaml_path, "r") as f:
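For context, both branches of _get_kernel_metadata_for_model assemble the same kernel-key string: a "v1/" version prefix followed by one dtype;dim_order segment per tensor argument, joined with "|" and with the trailing separator stripped. A minimal sketch of that format, using hypothetical dtype codes and dim orders rather than values read from a real .pte file:

# Illustrative sketch only: the dtype codes and dim orders below are made up,
# not read from a real program.
version = "v1"
tensor_args = [
    (6, (0, 1, 2, 3)),  # hypothetical: dtype code 6 with a contiguous dim order
    (6, (0, 1, 2, 3)),
]

kernel_key = version + "/"
for dtype, dim_order in tensor_args:
    kernel_key += f"{dtype};{','.join(map(str, dim_order))}|"
kernel_key = kernel_key[:-1]  # drop the trailing "|"

print(kernel_key)  # v1/6;0,1,2,3|6;0,1,2,3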

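A hedged usage sketch of the two helpers follows, assuming the file is importable as executorch.codegen.tools.gen_oplist and that a serialized program exists at the placeholder path "model.pte". Whichever import path succeeds inside the helpers (the compiled selective_build extension or the pure-Python _deserialize_pte_binary fallback), callers see the same operator list and kernel-key mapping:

# Hypothetical usage; "model.pte" is a placeholder path, and the module path
# is assumed from the file's location (codegen/tools/gen_oplist.py).
from executorch.codegen.tools.gen_oplist import (
    _get_kernel_metadata_for_model,
    _get_operators,
)

ops = _get_operators("model.pte")
kernel_metadata = _get_kernel_metadata_for_model("model.pte")

for op_name in ops:
    # Each operator maps to zero or more "v1/dtype;dim_order|..." kernel keys.
    print(op_name, kernel_metadata.get(op_name, []))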