1- #!/usr/bin/env fbpython
21# Copyright (c) Meta Platforms, Inc. and affiliates.
32# All rights reserved.
43#
76
87import argparse
98import os
9+ import re
1010import sys
11+ from functools import reduce
12+ from pathlib import Path
1113from typing import Any , List
1214
13- from tools_copy .code_analyzer import gen_oplist_copy_from_core
15+ import yaml
16+ from torchgen .selective_build .selector import (
17+ combine_selective_builders ,
18+ SelectiveBuilder ,
19+ )
20+
21+
22+ def throw_if_any_op_includes_overloads (selective_builder : SelectiveBuilder ) -> None :
23+ ops = []
24+ for op_name , op in selective_builder .operators .items ():
25+ if op .include_all_overloads :
26+ ops .append (op_name )
27+ if ops :
28+ raise Exception ( # noqa: TRY002
29+ (
30+ "Operators that include all overloads are "
31+ + "not allowed since --allow-include-all-overloads "
32+ + "was not specified: {}"
33+ ).format (", " .join (ops ))
34+ )
35+
36+
37+ def resolve_model_file_path_to_buck_target (model_file_path : str ) -> str :
38+ real_path = str (Path (model_file_path ).resolve (strict = True ))
39+ # try my best to convert to buck target
40+ prog = re .compile (
41+ r"/.*/buck-out/.*/(fbsource|fbcode)/[0-9a-f]*/(.*)/__(.*)_et_oplist__/out/selected_operators.yaml"
42+ )
43+ match = prog .match (real_path )
44+ if match :
45+ return f"{ match .group (1 )} //{ match .group (2 )} :{ match .group (3 )} "
46+ else :
47+ return real_path
1448
1549
1650def main (argv : List [Any ]) -> None :
17- """This binary is a wrapper for //executorch/codegen/tools/gen_oplist_copy_from_core.py.
18- This is needed because we intend to error out for the case where `model_file_list_path`
19- is empty or invalid, so that the ExecuTorch build will fail when no selective build target
20- is provided as a dependency to ExecuTorch build.
51+ """This binary generates 3 files:
52+
53+ 1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
54+ dtypes captured by tracing
55+ 2. selected_operators.yaml: Selected root and non-root operators (either via tracing or static analysis)
2156 """
2257 parser = argparse .ArgumentParser (description = "Generate operator lists" )
2358 parser .add_argument (
59+ "--output-dir" ,
2460 "--output_dir" ,
2561 help = ("The directory to store the output yaml file (selected_operators.yaml)" ),
2662 required = True ,
2763 )
2864 parser .add_argument (
65+ "--model-file-list-path" ,
2966 "--model_file_list_path" ,
3067 help = (
3168 "Path to a file that contains the locations of individual "
@@ -36,6 +73,7 @@ def main(argv: List[Any]) -> None:
3673 required = True ,
3774 )
3875 parser .add_argument (
76+ "--allow-include-all-overloads" ,
3977 "--allow_include_all_overloads" ,
4078 help = (
4179 "Flag to allow operators that include all overloads. "
@@ -46,26 +84,109 @@ def main(argv: List[Any]) -> None:
4684 default = False ,
4785 required = False ,
4886 )
87+ parser .add_argument (
88+ "--check-ops-not-overlapping" ,
89+ "--check_ops_not_overlapping" ,
90+ help = (
91+ "Flag to check if the operators in the model file list are overlapping. "
92+ + "If not set, the script will not error out for overlapping operators."
93+ ),
94+ action = "store_true" ,
95+ default = False ,
96+ required = False ,
97+ )
98+ options = parser .parse_args (argv )
4999
50- # check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
100+ # Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
51101 # 1. a yaml file containing selected ops (could be empty), or
52- # 2. a non-empty list of yaml files in the `model_file_list_path`.
53- # If none of the two things happened, the build target has no dependency on any selective build and we should error out .
54- options = parser . parse_args ( argv )
102+ # 2. a non-empty list of yaml files in the `model_file_list_path` or
103+ # 3. a non-empty list of directories in the `model_file_list_path`, with each directory containing a `selected_operators.yaml` file .
104+ # If none of the 3 things happened, the build target has no dependency on any selective build and we should error out.
55105 if os .path .isfile (options .model_file_list_path ):
56- pass
106+ print ("Processing model file: " , options .model_file_list_path )
107+ model_dicts = []
108+ model_dict = yaml .safe_load (open (options .model_file_list_path ))
109+ model_dicts .append (model_dict )
57110 else :
111+ print ("Processing model file list or model directory list: " , options .model_file_list_path )
58112 assert (
59113 options .model_file_list_path [0 ] == "@"
60114 ), "model_file_list_path is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."
115+
61116 model_file_list_path = options .model_file_list_path [1 :]
117+
118+ model_dicts = []
62119 with open (model_file_list_path ) as model_list_file :
63120 model_file_names = model_list_file .read ().split ()
64121 assert (
65122 len (model_file_names ) > 0
66123 ), "BUCK was not able to find any `et_operator_library` in the dependency graph of the current ExecuTorch "
67124 "build. Please refer to Selective Build wiki page to add at least one."
68- gen_oplist_copy_from_core .main (argv )
125+ for model_file_name in model_file_names :
126+ if not os .path .isfile (model_file_name ):
127+ model_file_name = os .path .join (
128+ model_file_name , "selected_operators.yaml"
129+ )
130+ print ("Processing model file: " , model_file_name )
131+ assert os .path .isfile (
132+ model_file_name
133+ ), f"{ model_file_name } is not a valid file path. This is likely a BUCK issue."
134+ with open (model_file_name , "rb" ) as model_file :
135+ model_dict = yaml .safe_load (model_file )
136+ resolved = resolve_model_file_path_to_buck_target (model_file_name )
137+ for op in model_dict ["operators" ]:
138+ model_dict ["operators" ][op ]["debug_info" ] = [resolved ]
139+ model_dicts .append (model_dict )
140+
141+ selective_builders = [SelectiveBuilder .from_yaml_dict (m ) for m in model_dicts ]
142+
143+ # Optionally check if the operators in the model file list are overlapping.
144+ if options .check_ops_not_overlapping :
145+ ops = {}
146+ for model_dict in model_dicts :
147+ for op_name in model_dict ["operators" ]:
148+ if op_name in ops :
149+ debug_info_1 = "," .join (ops [op_name ]["debug_info" ])
150+ debug_info_2 = "," .join (
151+ model_dict ["operators" ][op_name ]["debug_info" ]
152+ )
153+ error = f"Operator { op_name } is used in 2 models: { debug_info_1 } and { debug_info_2 } "
154+ if "//" not in debug_info_1 and "//" not in debug_info_2 :
155+ error += "\n We can't determine what BUCK targets these model files belong to."
156+ tail = "."
157+ else :
158+ error += "\n Please run the following commands to find out where is the BUCK target being added as a dependency to your target:\n "
159+ error += f'\n buck2 cquery <mode> "allpaths(<target>, { debug_info_1 } )"'
160+ error += f'\n buck2 cquery <mode> "allpaths(<target>, { debug_info_2 } )"'
161+ tail = "as well as results from BUCK commands listed above."
162+
163+ error += (
164+ "\n \n If issue is not resolved, please post in PyTorch Edge Q&A with this error message"
165+ + tail
166+ )
167+ raise Exception (error ) # noqa: TRY002
168+ ops [op_name ] = model_dict ["operators" ][op_name ]
169+ # We may have 0 selective builders since there may not be any viable
170+ # pt_operator_library rule marked as a dep for the pt_operator_registry rule.
171+ # This is potentially an error, and we should probably raise an assertion
172+ # failure here. However, this needs to be investigated further.
173+ selective_builder = SelectiveBuilder .from_yaml_dict ({})
174+ if len (selective_builders ) > 0 :
175+ selective_builder = reduce (
176+ combine_selective_builders ,
177+ selective_builders ,
178+ )
179+
180+ if not options .allow_include_all_overloads :
181+ throw_if_any_op_includes_overloads (selective_builder )
182+ with open (
183+ os .path .join (options .output_dir , "selected_operators.yaml" ), "wb"
184+ ) as out_file :
185+ out_file .write (
186+ yaml .safe_dump (
187+ selective_builder .to_dict (), default_flow_style = False
188+ ).encode ("utf-8" ),
189+ )
69190
70191
71192if __name__ == "__main__" :
0 commit comments