Skip to content

Commit ffea630

Browse files
kimishpatelfacebook-github-bot
authored andcommitted
Add selective build support for prim ops
Summary: This diff implements selective build functionality for primitive operations (prim ops) in ExecutorTorch, allowing users to include only specific prim ops in their builds to reduce binary size and compilation time. ## Key Changes: 1. **Conditional compilation in register_prim_ops.cpp**: Wrapped each of the prim op registrations with conditional compilation macros that check both selective build enablement (`EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD`) and individual op selection (e.g., `INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT`). 2. **Code generation tool**: Added `gen_selected_prim_ops.py` that takes comma-separated prim op names and generates a header file (`selected_prim_ops.h`) containing appropriate `#define` statements for selected ops. The tool normalizes op names to macro-safe format (e.g., `executorch_prim::et_view.default` → `INCLUDE_EXECUTORCH_PRIM_ET_VIEW_DEFAULT`). 3. **Build system integration**: In order to make et_operator_library also handle prim of selective build we make a few changes. 1. Extract prim ops in et_operator_library 2. Similar to gen_op_list, we invoke script that geneates selected_prim_ops.h file per et_operator_library target. Thus et_operator_library now generates selected_operators.yaml and selected_prim_ops.h. Note that in order to make these work we have to allow et_operator_libray to handle the following cases. 1. All ops are aten ops 2. All ops are prim ops 3. Mix To do this we must make sure that the genrule continues to produce the file it says it will produce. In the case of 1 we have to produce empty selected_prim_opsh. and in case 2 we have to produce emtpy selected_operators.yaml 3. In gen_all_oplist we allow for empty selected_operators.yaml and skip the file. 4. Similar to gen_all_oplist we introduce another binary that combines all selected_prim_ops.h. 5. Then in executorch_generated_lib we query targets from 4 that have selected_prim_ops and use those to compile register_prim_ops.cpp. In executorch_generate_lib we introduce include_all_prim_ops which by default is True. Hence if one wants to enable selective build for prim ops one must turn off that flag ## Usage: Users can now specify prim ops like: ``` et_operator_library(name="my_aten_prim_ops", ops=["aten::mul.out", "executorch_prim::et_view.default", "aten::sym_size.int"]) executorch_generated_lib(name="my_lib", deps=[":my_aten_prim_ops"] + other_deps, include_all_prim_ops=False) ``` Reviewed By: ivayloen, larryliu0820 Differential Revision: D81648030
1 parent d1e1bf8 commit ffea630

File tree

13 files changed

+903
-14
lines changed

13 files changed

+903
-14
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
"""
9+
Script to combine multiple selected_prim_ops.h header files into a single header.
10+
This is used by selected_prim_operators_genrule to merge prim ops headers from dependencies.
11+
"""
12+
13+
import argparse
14+
import os
15+
import sys
16+
from pathlib import Path
17+
from typing import List, Set
18+
19+
20+
def read_header_file(file_path: Path) -> Set[str]:
21+
"""
22+
Read a selected_prim_ops.h file and extract the macros and comments.
23+
24+
Args:
25+
file_path: Path to the header file
26+
27+
Returns:
28+
macros_set where macros_set contains unique macro defines
29+
"""
30+
macros = set()
31+
32+
try:
33+
with open(file_path, 'r') as f:
34+
for line in f:
35+
line = line.strip()
36+
37+
# Extract #define statements for prim ops
38+
if line.startswith('#define INCLUDE_') and not line.startswith('#define EXECUTORCH_ENABLE'):
39+
macros.add(line)
40+
except FileNotFoundError:
41+
print(f"Warning: Header file not found: {file_path}", file=sys.stderr)
42+
except Exception as e:
43+
print(f"Error reading {file_path}: {e}", file=sys.stderr)
44+
45+
return macros
46+
47+
48+
def combine_prim_ops_headers(header_file_paths: List[str], output_path: str) -> None:
49+
"""
50+
Combine multiple selected_prim_ops.h files into a single header.
51+
52+
Args:
53+
header_files: List of paths to header files to combine
54+
output_path: Path to output the combined header
55+
"""
56+
all_macros = set()
57+
has_selective_build = False
58+
59+
# Read all header files and collect unique macros
60+
for header_file_path in header_file_paths:
61+
header_file = Path(header_file_path) / 'selected_prim_ops.h'
62+
if os.path.exists(header_file):
63+
macros = read_header_file(header_file)
64+
all_macros.update(macros)
65+
if len(all_macros) > 0:
66+
has_selective_build = True
67+
else:
68+
print(f"Warning: Header file does not exist: {header_file}", file=sys.stderr)
69+
70+
# Generate combined header
71+
header_content = [
72+
"// Combined header for selective prim ops build",
73+
"// This file is auto-generated by combining multiple selected_prim_ops.h files",
74+
"// Do not edit manually.",
75+
"",
76+
"#pragma once",
77+
"",
78+
]
79+
80+
if all_macros and has_selective_build:
81+
header_content.extend([
82+
"// Enable selective build for prim ops",
83+
"#define EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD",
84+
"",
85+
"// Combined prim ops macros from all dependencies",
86+
])
87+
88+
# Sort macros for deterministic output
89+
sorted_macros = sorted(all_macros)
90+
header_content.extend(sorted_macros)
91+
else:
92+
header_content.extend([
93+
"// No prim ops found in dependencies - all prim ops will be included",
94+
"// Selective build is disabled",
95+
])
96+
97+
header_content.append("")
98+
99+
# Write the combined header
100+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
101+
with open(output_path, 'w') as f:
102+
f.write('\n'.join(header_content))
103+
104+
def _get_header_file_paths_from_query_output(query_output_file: str) -> List[str]:
105+
"""
106+
Parse the output of a Buck query command to extract header file paths.
107+
108+
Args:
109+
query_output_file: Path to the file containing the query output
110+
111+
Returns:
112+
List of header file paths
113+
"""
114+
header_file_paths = []
115+
assert query_output_file[0] == '@', "query_output_file is not a valid file path, or it doesn't start with '@'."
116+
query_output_file = query_output_file[1:]
117+
118+
with open(query_output_file, 'r') as f:
119+
for line in f:
120+
# Extract the header file path from the query output
121+
header_file_paths += line.split()
122+
return header_file_paths
123+
124+
125+
def main():
126+
parser = argparse.ArgumentParser(description='Combine multiple selected_prim_ops.h header files')
127+
parser.add_argument('--header_files', required=True, help='Comma-separated list of header file paths')
128+
parser.add_argument('--output_dir', required=True, help='Output directory for combined header')
129+
130+
args = parser.parse_args()
131+
import os
132+
header_file_paths = _get_header_file_paths_from_query_output(args.header_files)
133+
134+
if not header_file_paths:
135+
print("Error: No header files provided", file=sys.stderr)
136+
sys.exit(1)
137+
138+
# Generate output path
139+
output_path = os.path.join(args.output_dir, 'selected_prim_ops.h')
140+
141+
combine_prim_ops_headers(header_file_paths, output_path)
142+
143+
144+
if __name__ == '__main__':
145+
main()

codegen/tools/gen_all_oplist.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import sys
1111
from functools import reduce
1212
from pathlib import Path
13-
from typing import Any, List
13+
from typing import Any, Dict, List
1414

1515
import yaml
1616
from torchgen.selective_build.selector import (
@@ -71,6 +71,16 @@ def _raise_if_check_prim_ops_fail(options):
7171
)
7272
raise Exception(error)
7373

74+
def _selected_ops_model_dict_is_empty(model_dict: Dict[str, Any]) -> bool:
75+
return (
76+
not model_dict.get("build_features", []) and
77+
not model_dict.get("custom_classes", []) and
78+
not model_dict.get("et_kernel_metadata", None) and
79+
not model_dict.get("include_all_non_op_selectives", False) and
80+
not model_dict.get("include_all_operators", False) and
81+
not model_dict.get("kernel_metadata", {}) and
82+
not model_dict.get("operators", {})
83+
)
7484

7585
def main(argv: List[Any]) -> None:
7686
"""This binary generates 3 files:
@@ -171,6 +181,11 @@ def main(argv: List[Any]) -> None:
171181
), f"{model_file_name} is not a valid file path. This is likely a BUCK issue."
172182
with open(model_file_name, "rb") as model_file:
173183
model_dict = yaml.safe_load(model_file)
184+
# It is possible that we created an empty yaml file.
185+
# This is because et_operator_library may only contain prim ops.
186+
# In that case selected_operators.yaml will be empty.
187+
if _selected_ops_model_dict_is_empty(model_dict):
188+
continue
174189
resolved = resolve_model_file_path_to_buck_target(model_file_name)
175190
for op in model_dict["operators"]:
176191
model_dict["operators"][op]["debug_info"] = [resolved]

codegen/tools/gen_oplist.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import sys
1111
from enum import IntEnum
1212
from typing import Any, Dict, List, Optional, Set
13+
from pathlib import Path
1314

1415
import yaml
1516

@@ -219,13 +220,16 @@ def gen_oplist(
219220
ops_dict: Optional[str] = None,
220221
include_all_operators: bool = False,
221222
):
222-
assert (
223+
if not (
223224
model_file_path
224225
or ops_schema_yaml_path
225226
or root_ops
226227
or ops_dict
227228
or include_all_operators
228-
), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators."
229+
):
230+
# dump empty yaml file
231+
_dump_yaml([], output_path)
232+
return
229233

230234
assert output_path, "Need to provide output_path for dumped yaml file."
231235
op_set = set()
@@ -326,9 +330,10 @@ def main(args: List[Any]) -> None:
326330
)
327331
options = parser.parse_args(args)
328332

333+
output_path = Path(options.output_path) / "selected_operators.yaml"
329334
try:
330335
gen_oplist(
331-
output_path=options.output_path,
336+
output_path=output_path,
332337
model_file_path=options.model_file_path,
333338
ops_schema_yaml_path=options.ops_schema_yaml_path,
334339
root_ops=options.root_ops,
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#!/usr/bin/env fbpython
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-unsafe
9+
10+
import argparse
11+
import os
12+
import sys
13+
from typing import Any, List
14+
15+
from torchgen.code_template import CodeTemplate
16+
17+
18+
selected_prim_ops_h_template_str = """#pragma once
19+
/**
20+
* Generated by executorch/codegen/tools/gen_selected_prim_ops.py
21+
*/
22+
23+
$defines
24+
"""
25+
selected_prim_ops_h_template = CodeTemplate(selected_prim_ops_h_template_str)
26+
27+
28+
def normalize_op_name(op_name: str) -> str:
29+
"""
30+
Normalize an operator name to a macro-safe format.
31+
Convert op names like "executorch_prim::et_view.default" to "EXECUTORCH_PRIM_ET_VIEW_DEFAULT"
32+
or "aten::sym_size.int" to "ATEN_SYM_SIZE_INT"
33+
"""
34+
# Remove namespace separator and replace with underscore
35+
normalized = op_name.replace("::", "_")
36+
# Replace dots with underscores
37+
normalized = normalized.replace(".", "_")
38+
# Convert to uppercase
39+
normalized = normalized.upper()
40+
# Add INCLUDE_ prefix
41+
normalized = f"INCLUDE_{normalized}"
42+
return normalized
43+
44+
45+
def write_selected_prim_ops(prim_op_names: List[str], output_dir: str) -> None:
46+
"""
47+
Generate selected_prim_ops.h from a list of prim op names.
48+
49+
Args:
50+
prim_op_names: List of prim op names like ["executorch_prim::et_view.default", "aten::sym_size.int"]
51+
output_dir: Directory where to write selected_prim_ops.h
52+
"""
53+
# Generate #define statements for each op
54+
defines = []
55+
for op_name in prim_op_names:
56+
macro_name = normalize_op_name(op_name)
57+
defines.append(f"#define {macro_name}")
58+
59+
# Join all defines with newlines
60+
defines_str = "\n".join(defines)
61+
62+
# Generate header content
63+
header_contents = selected_prim_ops_h_template.substitute(defines=defines_str)
64+
65+
# Write to file
66+
selected_prim_ops_path = os.path.join(output_dir, "selected_prim_ops.h")
67+
with open(selected_prim_ops_path, "wb") as out_file:
68+
out_file.write(header_contents.encode("utf-8"))
69+
70+
71+
def main(argv: List[Any]) -> None:
72+
parser = argparse.ArgumentParser(description="Generate selected prim ops header")
73+
parser.add_argument(
74+
"--prim-op-names",
75+
"--prim_op_names",
76+
help="Comma-separated list of prim op names to include",
77+
required=True,
78+
)
79+
parser.add_argument(
80+
"--output-dir",
81+
"--output_dir",
82+
help="The directory to store the output header file (selected_prim_ops.h)",
83+
required=True,
84+
)
85+
86+
options = parser.parse_args(argv)
87+
88+
# Parse comma-separated prim op names
89+
prim_op_names = [name.strip() for name in options.prim_op_names.split(",") if name.strip()]
90+
91+
write_selected_prim_ops(prim_op_names, options.output_dir)
92+
93+
94+
if __name__ == "__main__":
95+
main(sys.argv[1:])

codegen/tools/targets.bzl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,26 @@ def define_common_targets(is_fbcode = False):
103103
_is_external_target = True,
104104
)
105105

106+
runtime.python_library(
107+
name = "combine_prim_ops_headers_lib",
108+
srcs = ["combine_prim_ops_headers.py"],
109+
base_module = "executorch.codegen.tools",
110+
visibility = ["//executorch/..."],
111+
)
112+
113+
runtime.python_binary(
114+
name = "combine_prim_ops_headers",
115+
main_module = "executorch.codegen.tools.combine_prim_ops_headers",
116+
package_style = "inplace",
117+
visibility = [
118+
"PUBLIC",
119+
],
120+
deps = [
121+
":combine_prim_ops_headers_lib",
122+
],
123+
_is_external_target = True,
124+
)
125+
106126
runtime.python_test(
107127
name = "test_gen_all_oplist",
108128
srcs = [
@@ -155,6 +175,27 @@ def define_common_targets(is_fbcode = False):
155175
_is_external_target = True,
156176
)
157177

178+
runtime.python_library(
179+
name = "gen_selected_prim_ops_lib",
180+
srcs = ["gen_selected_prim_ops.py"],
181+
base_module = "executorch.codegen.tools",
182+
visibility = ["//executorch/..."],
183+
external_deps = ["torchgen"],
184+
)
185+
186+
runtime.python_binary(
187+
name = "gen_selected_prim_ops",
188+
main_module = "executorch.codegen.tools.gen_selected_prim_ops",
189+
package_style = "inplace",
190+
visibility = [
191+
"PUBLIC",
192+
],
193+
deps = [
194+
":gen_selected_prim_ops_lib",
195+
],
196+
_is_external_target = True,
197+
)
198+
158199
if not runtime.is_oss:
159200
runtime.cxx_python_extension(
160201
name = "selective_build",

0 commit comments

Comments
 (0)