Skip to content

Commit 07d1092

Browse files
authored
Add selective build support for prim ops
Differential Revision: D81648030 Pull Request resolved: #14332
1 parent 8b11418 commit 07d1092

File tree

11 files changed

+734
-22
lines changed

11 files changed

+734
-22
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Script to combine multiple selected_prim_ops.h header files into a single header.
9+
This is used by selected_prim_operators_genrule to merge prim ops headers from dependencies.
10+
"""
11+
12+
import argparse
13+
import os
14+
import sys
15+
from pathlib import Path
16+
from typing import List, Set
17+
18+
19+
def read_header_file(file_path: Path) -> Set[str]:
20+
"""
21+
Read a selected_prim_ops.h file and extract the macros and comments.
22+
23+
Args:
24+
file_path: Path to the header file
25+
26+
Returns:
27+
macros_set where macros_set contains unique macro defines
28+
"""
29+
macros = set()
30+
31+
try:
32+
with open(file_path, "r") as f:
33+
for line in f:
34+
line = line.strip()
35+
36+
# Extract #define statements for prim ops
37+
if line.startswith("#define INCLUDE_") and not line.startswith(
38+
"#define EXECUTORCH_ENABLE"
39+
):
40+
macros.add(line)
41+
except FileNotFoundError:
42+
print(f"Warning: Header file not found: {file_path}", file=sys.stderr)
43+
except Exception as e:
44+
print(f"Error reading {file_path}: {e}", file=sys.stderr)
45+
46+
return macros
47+
48+
49+
def combine_prim_ops_headers(header_file_paths: List[str], output_path: str) -> None:
50+
"""
51+
Combine multiple selected_prim_ops.h files into a single header.
52+
53+
Args:
54+
header_files: List of paths to header files to combine
55+
output_path: Path to output the combined header
56+
"""
57+
all_macros = set()
58+
has_selective_build = False
59+
60+
# Read all header files and collect unique macros
61+
for header_file_path in header_file_paths:
62+
header_file = Path(header_file_path) / "selected_prim_ops.h"
63+
if os.path.exists(header_file):
64+
macros = read_header_file(header_file)
65+
all_macros.update(macros)
66+
if len(all_macros) > 0:
67+
has_selective_build = True
68+
else:
69+
print(
70+
f"Warning: Header file does not exist: {header_file}", file=sys.stderr
71+
)
72+
73+
# Generate combined header
74+
header_content = [
75+
"// Combined header for selective prim ops build",
76+
"// This file is auto-generated by combining multiple selected_prim_ops.h files",
77+
"// Do not edit manually.",
78+
"",
79+
"#pragma once",
80+
"",
81+
]
82+
83+
if all_macros and has_selective_build:
84+
header_content.extend(
85+
[
86+
"// Enable selective build for prim ops",
87+
"#define EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD",
88+
"",
89+
"// Combined prim ops macros from all dependencies",
90+
]
91+
)
92+
93+
# Sort macros for deterministic output
94+
sorted_macros = sorted(all_macros)
95+
header_content.extend(sorted_macros)
96+
else:
97+
header_content.extend(
98+
[
99+
"// No prim ops found in dependencies - all prim ops will be included",
100+
"// Selective build is disabled",
101+
]
102+
)
103+
104+
header_content.append("")
105+
106+
# Write the combined header
107+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
108+
with open(output_path, "w") as f:
109+
f.write("\n".join(header_content))
110+
111+
112+
def _get_header_file_paths_from_query_output(query_output_file: str) -> List[str]:
113+
"""
114+
Parse the output of a Buck query command to extract header file paths.
115+
116+
Args:
117+
query_output_file: Path to the file containing the query output
118+
119+
Returns:
120+
List of header file paths
121+
"""
122+
header_file_paths = []
123+
assert (
124+
query_output_file[0] == "@"
125+
), "query_output_file is not a valid file path, or it doesn't start with '@'."
126+
query_output_file = query_output_file[1:]
127+
128+
with open(query_output_file, "r") as f:
129+
for line in f:
130+
# Extract the header file path from the query output
131+
header_file_paths += line.split()
132+
return header_file_paths
133+
134+
135+
def main():
136+
parser = argparse.ArgumentParser(
137+
description="Combine multiple selected_prim_ops.h header files"
138+
)
139+
parser.add_argument(
140+
"--header_files",
141+
required=True,
142+
help="Comma-separated list of header file paths",
143+
)
144+
parser.add_argument(
145+
"--output_dir", required=True, help="Output directory for combined header"
146+
)
147+
148+
args = parser.parse_args()
149+
import os
150+
151+
header_file_paths = _get_header_file_paths_from_query_output(args.header_files)
152+
153+
if not header_file_paths:
154+
print("Error: No header files provided", file=sys.stderr)
155+
sys.exit(1)
156+
157+
# Generate output path
158+
output_path = os.path.join(args.output_dir, "selected_prim_ops.h")
159+
160+
combine_prim_ops_headers(header_file_paths, output_path)
161+
162+
163+
if __name__ == "__main__":
164+
main()

codegen/tools/gen_all_oplist.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import sys
1111
from functools import reduce
1212
from pathlib import Path
13-
from typing import Any, List
13+
from typing import Any, Dict, List
1414

1515
import yaml
1616
from torchgen.selective_build.selector import (
@@ -72,6 +72,19 @@ def _raise_if_check_prim_ops_fail(options):
7272
raise Exception(error)
7373

7474

75+
def _selected_ops_model_dict_is_empty(model_dict: Dict[str, Any]) -> bool:
76+
return (
77+
not model_dict.get("build_features", [])
78+
and not model_dict.get("custom_classes", [])
79+
and not model_dict.get("et_kernel_metadata", None)
80+
and not model_dict.get("include_all_non_op_selectives", False)
81+
and not model_dict.get("include_all_operators", False)
82+
and not model_dict.get("kernel_metadata", {})
83+
and not model_dict.get("operators", {})
84+
)
85+
86+
87+
# flake8: noqa: C901
7588
def main(argv: List[Any]) -> None:
7689
"""This binary generates 3 files:
7790
@@ -171,6 +184,11 @@ def main(argv: List[Any]) -> None:
171184
), f"{model_file_name} is not a valid file path. This is likely a BUCK issue."
172185
with open(model_file_name, "rb") as model_file:
173186
model_dict = yaml.safe_load(model_file)
187+
# It is possible that we created an empty yaml file.
188+
# This is because et_operator_library may only contain prim ops.
189+
# In that case selected_operators.yaml will be empty.
190+
if _selected_ops_model_dict_is_empty(model_dict):
191+
continue
174192
resolved = resolve_model_file_path_to_buck_target(model_file_name)
175193
for op in model_dict["operators"]:
176194
model_dict["operators"][op]["debug_info"] = [resolved]

codegen/tools/gen_oplist.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010
import sys
1111
from enum import IntEnum
12+
from pathlib import Path
1213
from typing import Any, Dict, List, Optional, Set
1314

1415
import yaml
@@ -158,7 +159,7 @@ def _get_et_kernel_metadata_from_ops_yaml(ops_yaml_path: str) -> Dict[str, List[
158159

159160
def _dump_yaml(
160161
op_list: List[str],
161-
output_path: str,
162+
output_path: Path,
162163
model_name: Optional[str] = None,
163164
et_kernel_metadata: Optional[Dict[str, List[str]]] = None,
164165
include_all_operators: bool = False,
@@ -212,20 +213,23 @@ def create_kernel_key(maybe_kernel_key: str) -> str:
212213

213214

214215
def gen_oplist(
215-
output_path: str,
216+
output_path: Path,
216217
model_file_path: Optional[str] = None,
217218
ops_schema_yaml_path: Optional[str] = None,
218219
root_ops: Optional[str] = None,
219220
ops_dict: Optional[str] = None,
220221
include_all_operators: bool = False,
221222
):
222-
assert (
223+
if not (
223224
model_file_path
224225
or ops_schema_yaml_path
225226
or root_ops
226227
or ops_dict
227228
or include_all_operators
228-
), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators."
229+
):
230+
# dump empty yaml file
231+
_dump_yaml([], output_path)
232+
return
229233

230234
assert output_path, "Need to provide output_path for dumped yaml file."
231235
op_set = set()
@@ -326,9 +330,15 @@ def main(args: List[Any]) -> None:
326330
)
327331
options = parser.parse_args(args)
328332

333+
# check if the output_path is a directory, then generate operators
334+
# under selected_operators.yaml
335+
if Path(options.output_path).is_dir():
336+
output_path = Path(options.output_path) / "selected_operators.yaml"
337+
else:
338+
output_path = Path(options.output_path)
329339
try:
330340
gen_oplist(
331-
output_path=options.output_path,
341+
output_path=output_path,
332342
model_file_path=options.model_file_path,
333343
ops_schema_yaml_path=options.ops_schema_yaml_path,
334344
root_ops=options.root_ops,
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import argparse
10+
import os
11+
import sys
12+
from typing import Any, List
13+
14+
from torchgen.code_template import CodeTemplate # type: ignore[import-not-found]
15+
16+
17+
selected_prim_ops_h_template_str = """#pragma once
18+
/**
19+
* Generated by executorch/codegen/tools/gen_selected_prim_ops.py
20+
*/
21+
22+
$defines
23+
"""
24+
selected_prim_ops_h_template = CodeTemplate(selected_prim_ops_h_template_str)
25+
26+
27+
def normalize_op_name(op_name: str) -> str:
28+
"""
29+
Normalize an operator name to a macro-safe format.
30+
Convert op names like "executorch_prim::et_view.default" to "EXECUTORCH_PRIM_ET_VIEW_DEFAULT"
31+
or "aten::sym_size.int" to "ATEN_SYM_SIZE_INT"
32+
"""
33+
# Remove namespace separator and replace with underscore
34+
normalized = op_name.replace("::", "_")
35+
# Replace dots with underscores
36+
normalized = normalized.replace(".", "_")
37+
# Convert to uppercase
38+
normalized = normalized.upper()
39+
# Add INCLUDE_ prefix
40+
normalized = f"INCLUDE_{normalized}"
41+
return normalized
42+
43+
44+
def write_selected_prim_ops(prim_op_names: List[str], output_dir: str) -> None:
45+
"""
46+
Generate selected_prim_ops.h from a list of prim op names.
47+
48+
Args:
49+
prim_op_names: List of prim op names like ["executorch_prim::et_view.default", "aten::sym_size.int"]
50+
output_dir: Directory where to write selected_prim_ops.h
51+
"""
52+
# Generate #define statements for each op
53+
defines = []
54+
for op_name in prim_op_names:
55+
macro_name = normalize_op_name(op_name)
56+
defines.append(f"#define {macro_name}")
57+
58+
# Join all defines with newlines
59+
defines_str = "\n".join(defines)
60+
61+
# Generate header content
62+
header_contents = selected_prim_ops_h_template.substitute(defines=defines_str)
63+
64+
# Write to file
65+
selected_prim_ops_path = os.path.join(output_dir, "selected_prim_ops.h")
66+
with open(selected_prim_ops_path, "wb") as out_file:
67+
out_file.write(header_contents.encode("utf-8"))
68+
69+
70+
def main(argv: List[Any]) -> None:
71+
parser = argparse.ArgumentParser(description="Generate selected prim ops header")
72+
parser.add_argument(
73+
"--prim-op-names",
74+
"--prim_op_names",
75+
help="Comma-separated list of prim op names to include",
76+
required=True,
77+
)
78+
parser.add_argument(
79+
"--output-dir",
80+
"--output_dir",
81+
help="The directory to store the output header file (selected_prim_ops.h)",
82+
required=True,
83+
)
84+
85+
options = parser.parse_args(argv)
86+
87+
# Parse comma-separated prim op names
88+
prim_op_names = [
89+
name.strip() for name in options.prim_op_names.split(",") if name.strip()
90+
]
91+
92+
write_selected_prim_ops(prim_op_names, options.output_dir)
93+
94+
95+
if __name__ == "__main__":
96+
main(sys.argv[1:])

0 commit comments

Comments
 (0)