Skip to content

Commit cd2bd8b

Browse files
kimishpatelfacebook-github-bot
authored andcommitted
Add selective build support for prim ops (#14332)
Summary: This diff implements selective build functionality for primitive operations (prim ops) in ExecutorTorch, allowing users to include only specific prim ops in their builds to reduce binary size and compilation time. ## Key Changes: 1. **Conditional compilation in register_prim_ops.cpp**: Wrapped each of the prim op registrations with conditional compilation macros that check both selective build enablement (`EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD`) and individual op selection (e.g., `INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT`). 2. **Code generation tool**: Added `gen_selected_prim_ops.py` that takes comma-separated prim op names and generates a header file (`selected_prim_ops.h`) containing appropriate `#define` statements for selected ops. The tool normalizes op names to macro-safe format (e.g., `executorch_prim::et_view.default` → `INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT`). 3. **Build system integration**: In order to make et_operator_library also handle prim of selective build we make a few changes. 1. Extract prim ops in et_operator_library 2. Similar to gen_op_list, we invoke script that geneates selected_prim_ops.h file per et_operator_library target. Thus et_operator_library now generates selected_operators.yaml and selected_prim_ops.h. Note that in order to make these work we have to allow et_operator_libray to handle the following cases. 1. All ops are aten ops 2. All ops are prim ops 3. Mix To do this we must make sure that the genrule continues to produce the file it says it will produce. In the case of 1 we have to produce empty selected_prim_opsh. and in case 2 we have to produce emtpy selected_operators.yaml 3. In gen_all_oplist we allow for empty selected_operators.yaml and skip the file. 4. Similar to gen_all_oplist we introduce another binary that combines all selected_prim_ops.h. 5. Then in executorch_generated_lib we query targets from 4 that have selected_prim_ops and use those to compile register_prim_ops.cpp. In executorch_generate_lib we introduce include_all_prim_ops which by default is True. Hence if one wants to enable selective build for prim ops one must turn off that flag ## Usage: Users can now specify prim ops like: ``` et_operator_library(name="my_aten_prim_ops", ops=["aten::mul.out", "executorch_prim::et_view.default", "aten::sym_size.int"]) executorch_generated_lib(name="my_lib", deps=[":my_aten_prim_ops"] + other_deps, include_all_prim_ops=False) ``` Reviewed By: ivayloen, larryliu0820 Differential Revision: D81648030
1 parent 0f22062 commit cd2bd8b

File tree

12 files changed

+743
-22
lines changed

12 files changed

+743
-22
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Script to combine multiple selected_prim_ops.h header files into a single header.
9+
This is used by selected_prim_operators_genrule to merge prim ops headers from dependencies.
10+
"""
11+
12+
import argparse
13+
import os
14+
import sys
15+
from pathlib import Path
16+
from typing import List, Set
17+
18+
19+
def read_header_file(file_path: Path) -> Set[str]:
20+
"""
21+
Read a selected_prim_ops.h file and extract the macros and comments.
22+
23+
Args:
24+
file_path: Path to the header file
25+
26+
Returns:
27+
macros_set where macros_set contains unique macro defines
28+
"""
29+
macros = set()
30+
31+
try:
32+
with open(file_path, "r") as f:
33+
for line in f:
34+
line = line.strip()
35+
36+
# Extract #define statements for prim ops
37+
if line.startswith("#define INCLUDE_") and not line.startswith(
38+
"#define EXECUTORCH_ENABLE"
39+
):
40+
macros.add(line)
41+
except FileNotFoundError:
42+
print(f"Warning: Header file not found: {file_path}", file=sys.stderr)
43+
except Exception as e:
44+
print(f"Error reading {file_path}: {e}", file=sys.stderr)
45+
46+
return macros
47+
48+
49+
def combine_prim_ops_headers(header_file_paths: List[str], output_path: str) -> None:
50+
"""
51+
Combine multiple selected_prim_ops.h files into a single header.
52+
53+
Args:
54+
header_files: List of paths to header files to combine
55+
output_path: Path to output the combined header
56+
"""
57+
all_macros = set()
58+
has_selective_build = False
59+
60+
# Read all header files and collect unique macros
61+
for header_file_path in header_file_paths:
62+
header_file = Path(header_file_path) / "selected_prim_ops.h"
63+
if os.path.exists(header_file):
64+
macros = read_header_file(header_file)
65+
all_macros.update(macros)
66+
if len(all_macros) > 0:
67+
has_selective_build = True
68+
else:
69+
print(
70+
f"Warning: Header file does not exist: {header_file}", file=sys.stderr
71+
)
72+
73+
# Generate combined header
74+
header_content = [
75+
"// Combined header for selective prim ops build",
76+
"// This file is auto-generated by combining multiple selected_prim_ops.h files",
77+
"// Do not edit manually.",
78+
"",
79+
"#pragma once",
80+
"",
81+
]
82+
83+
if all_macros and has_selective_build:
84+
header_content.extend(
85+
[
86+
"// Enable selective build for prim ops",
87+
"#define EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD",
88+
"",
89+
"// Combined prim ops macros from all dependencies",
90+
]
91+
)
92+
93+
# Sort macros for deterministic output
94+
sorted_macros = sorted(all_macros)
95+
header_content.extend(sorted_macros)
96+
else:
97+
header_content.extend(
98+
[
99+
"// No prim ops found in dependencies - all prim ops will be included",
100+
"// Selective build is disabled",
101+
]
102+
)
103+
104+
header_content.append("")
105+
106+
# Write the combined header
107+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
108+
with open(output_path, "w") as f:
109+
f.write("\n".join(header_content))
110+
111+
112+
def _get_header_file_paths_from_query_output(query_output_file: str) -> List[str]:
113+
"""
114+
Parse the output of a Buck query command to extract header file paths.
115+
116+
Args:
117+
query_output_file: Path to the file containing the query output
118+
119+
Returns:
120+
List of header file paths
121+
"""
122+
header_file_paths = []
123+
assert (
124+
query_output_file[0] == "@"
125+
), "query_output_file is not a valid file path, or it doesn't start with '@'."
126+
query_output_file = query_output_file[1:]
127+
128+
with open(query_output_file, "r") as f:
129+
for line in f:
130+
# Extract the header file path from the query output
131+
header_file_paths += line.split()
132+
return header_file_paths
133+
134+
135+
def main():
136+
parser = argparse.ArgumentParser(
137+
description="Combine multiple selected_prim_ops.h header files"
138+
)
139+
parser.add_argument(
140+
"--header_files",
141+
required=True,
142+
help="Comma-separated list of header file paths",
143+
)
144+
parser.add_argument(
145+
"--output_dir", required=True, help="Output directory for combined header"
146+
)
147+
148+
args = parser.parse_args()
149+
import os
150+
151+
header_file_paths = _get_header_file_paths_from_query_output(args.header_files)
152+
153+
if not header_file_paths:
154+
print("Error: No header files provided", file=sys.stderr)
155+
sys.exit(1)
156+
157+
# Generate output path
158+
output_path = os.path.join(args.output_dir, "selected_prim_ops.h")
159+
160+
combine_prim_ops_headers(header_file_paths, output_path)
161+
162+
163+
if __name__ == "__main__":
164+
main()

codegen/tools/gen_all_oplist.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import sys
1111
from functools import reduce
1212
from pathlib import Path
13-
from typing import Any, List
13+
from typing import Any, Dict, List
1414

1515
import yaml
1616
from torchgen.selective_build.selector import (
@@ -72,6 +72,19 @@ def _raise_if_check_prim_ops_fail(options):
7272
raise Exception(error)
7373

7474

75+
def _selected_ops_model_dict_is_empty(model_dict: Dict[str, Any]) -> bool:
76+
return (
77+
not model_dict.get("build_features", [])
78+
and not model_dict.get("custom_classes", [])
79+
and not model_dict.get("et_kernel_metadata", None)
80+
and not model_dict.get("include_all_non_op_selectives", False)
81+
and not model_dict.get("include_all_operators", False)
82+
and not model_dict.get("kernel_metadata", {})
83+
and not model_dict.get("operators", {})
84+
)
85+
86+
87+
# flake8: noqa: C901
7588
def main(argv: List[Any]) -> None:
7689
"""This binary generates 3 files:
7790
@@ -171,6 +184,11 @@ def main(argv: List[Any]) -> None:
171184
), f"{model_file_name} is not a valid file path. This is likely a BUCK issue."
172185
with open(model_file_name, "rb") as model_file:
173186
model_dict = yaml.safe_load(model_file)
187+
# It is possible that we created an empty yaml file.
188+
# This is because et_operator_library may only contain prim ops.
189+
# In that case selected_operators.yaml will be empty.
190+
if _selected_ops_model_dict_is_empty(model_dict):
191+
continue
174192
resolved = resolve_model_file_path_to_buck_target(model_file_name)
175193
for op in model_dict["operators"]:
176194
model_dict["operators"][op]["debug_info"] = [resolved]

codegen/tools/gen_oplist.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010
import sys
1111
from enum import IntEnum
12+
from pathlib import Path
1213
from typing import Any, Dict, List, Optional, Set
1314

1415
import yaml
@@ -158,7 +159,7 @@ def _get_et_kernel_metadata_from_ops_yaml(ops_yaml_path: str) -> Dict[str, List[
158159

159160
def _dump_yaml(
160161
op_list: List[str],
161-
output_path: str,
162+
output_path: Path,
162163
model_name: Optional[str] = None,
163164
et_kernel_metadata: Optional[Dict[str, List[str]]] = None,
164165
include_all_operators: bool = False,
@@ -212,20 +213,23 @@ def create_kernel_key(maybe_kernel_key: str) -> str:
212213

213214

214215
def gen_oplist(
215-
output_path: str,
216+
output_path: Path,
216217
model_file_path: Optional[str] = None,
217218
ops_schema_yaml_path: Optional[str] = None,
218219
root_ops: Optional[str] = None,
219220
ops_dict: Optional[str] = None,
220221
include_all_operators: bool = False,
221222
):
222-
assert (
223+
if not (
223224
model_file_path
224225
or ops_schema_yaml_path
225226
or root_ops
226227
or ops_dict
227228
or include_all_operators
228-
), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators."
229+
):
230+
# dump empty yaml file
231+
_dump_yaml([], output_path)
232+
return
229233

230234
assert output_path, "Need to provide output_path for dumped yaml file."
231235
op_set = set()
@@ -326,9 +330,15 @@ def main(args: List[Any]) -> None:
326330
)
327331
options = parser.parse_args(args)
328332

333+
# check if the output_path is a directory, then generate operators
334+
# under selected_operators.yaml
335+
if Path(options.output_path).is_dir():
336+
output_path = Path(options.output_path) / "selected_operators.yaml"
337+
else:
338+
output_path = Path(options.output_path)
329339
try:
330340
gen_oplist(
331-
output_path=options.output_path,
341+
output_path=output_path,
332342
model_file_path=options.model_file_path,
333343
ops_schema_yaml_path=options.ops_schema_yaml_path,
334344
root_ops=options.root_ops,
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import argparse
10+
import os
11+
import sys
12+
from typing import Any, List
13+
14+
from torchgen.code_template import CodeTemplate # type: ignore[import-not-found]
15+
16+
17+
selected_prim_ops_h_template_str = """#pragma once
18+
/**
19+
* Generated by executorch/codegen/tools/gen_selected_prim_ops.py
20+
*/
21+
22+
$defines
23+
"""
24+
selected_prim_ops_h_template = CodeTemplate(selected_prim_ops_h_template_str)
25+
26+
27+
def normalize_op_name(op_name: str) -> str:
28+
"""
29+
Normalize an operator name to a macro-safe format.
30+
Convert op names like "executorch_prim::et_view.default" to "EXECUTORCH_PRIM_ET_VIEW_DEFAULT"
31+
or "aten::sym_size.int" to "ATEN_SYM_SIZE_INT"
32+
"""
33+
# Remove namespace separator and replace with underscore
34+
normalized = op_name.replace("::", "_")
35+
# Replace dots with underscores
36+
normalized = normalized.replace(".", "_")
37+
# Convert to uppercase
38+
normalized = normalized.upper()
39+
# Add INCLUDE_ prefix
40+
normalized = f"INCLUDE_{normalized}"
41+
return normalized
42+
43+
44+
def write_selected_prim_ops(prim_op_names: List[str], output_dir: str) -> None:
45+
"""
46+
Generate selected_prim_ops.h from a list of prim op names.
47+
48+
Args:
49+
prim_op_names: List of prim op names like ["executorch_prim::et_view.default", "aten::sym_size.int"]
50+
output_dir: Directory where to write selected_prim_ops.h
51+
"""
52+
# Generate #define statements for each op
53+
defines = []
54+
for op_name in prim_op_names:
55+
macro_name = normalize_op_name(op_name)
56+
defines.append(f"#define {macro_name}")
57+
58+
# Join all defines with newlines
59+
defines_str = "\n".join(defines)
60+
61+
# Generate header content
62+
header_contents = selected_prim_ops_h_template.substitute(defines=defines_str)
63+
64+
# Write to file
65+
selected_prim_ops_path = os.path.join(output_dir, "selected_prim_ops.h")
66+
with open(selected_prim_ops_path, "wb") as out_file:
67+
out_file.write(header_contents.encode("utf-8"))
68+
69+
70+
def main(argv: List[Any]) -> None:
71+
parser = argparse.ArgumentParser(description="Generate selected prim ops header")
72+
parser.add_argument(
73+
"--prim-op-names",
74+
"--prim_op_names",
75+
help="Comma-separated list of prim op names to include",
76+
required=True,
77+
)
78+
parser.add_argument(
79+
"--output-dir",
80+
"--output_dir",
81+
help="The directory to store the output header file (selected_prim_ops.h)",
82+
required=True,
83+
)
84+
85+
options = parser.parse_args(argv)
86+
87+
# Parse comma-separated prim op names
88+
prim_op_names = [
89+
name.strip() for name in options.prim_op_names.split(",") if name.strip()
90+
]
91+
92+
write_selected_prim_ops(prim_op_names, options.output_dir)
93+
94+
95+
if __name__ == "__main__":
96+
main(sys.argv[1:])

0 commit comments

Comments
 (0)