Skip to content

Commit 881915d

Browse files
authored
Add platforms for all operator library sub-targets.
Differential Revision: D83680406 Pull Request resolved: #14728
1 parent 3f0896a commit 881915d

File tree

2 files changed

+60
-49
lines changed

2 files changed

+60
-49
lines changed

shim_et/xplat/executorch/codegen/codegen.bzl

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_default_executorch_platforms", "is_xplat", "runtime", "struct_to_json")
22
load("@fbsource//xplat/executorch/build:selects.bzl", "selects")
3-
load("@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", "portable_source_list")
4-
load("@fbsource//xplat/executorch/kernels/optimized:op_registration_util.bzl", "optimized_source_list")
53
load(
64
"@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
75
"get_vec_deps",
86
"get_vec_preprocessor_flags",
97
)
8+
load("@fbsource//xplat/executorch/kernels/optimized:op_registration_util.bzl", "optimized_source_list")
9+
load("@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", "portable_source_list")
1010
load("@fbsource//xplat/executorch/kernels/prim_ops:selective_build.bzl", "prim_ops_registry_selective")
1111

1212
# Headers that declare the function signatures of the C++ functions that
@@ -96,15 +96,17 @@ def _get_prim_ops_registry_target(name, deps, aten_suffix, platforms):
9696
Returns:
9797
String: Target name for the appropriate prim ops registry
9898
"""
99+
99100
# If selective build targets are specified, create a selective prim ops registry
100101
# Create a selective prim ops registry using the existing function
101102
selective_prim_ops_registry_name = name + "_selected_prim_ops_registry"
102103
combined_prim_ops_header_target_name = name + "_combined_prim_ops_header"
103104
selected_prim_operators_genrule(combined_prim_ops_header_target_name, deps, platforms)
105+
104106
# Use the existing prim_ops_registry_selective function
105107
prim_ops_registry_selective(
106108
name = selective_prim_ops_registry_name,
107-
selected_prim_ops_header_target = ":"+combined_prim_ops_header_target_name,
109+
selected_prim_ops_header_target = ":" + combined_prim_ops_header_target_name,
108110
aten_suffix = aten_suffix,
109111
platforms = platforms,
110112
)
@@ -123,11 +125,16 @@ def _extract_prim_ops_from_lists(ops, ops_dict):
123125
Returns:
124126
Tuple of (prim_ops, remaining_ops, remaining_ops_dict)
125127
"""
128+
126129
def _is_aten_prim_op(op_name):
127130
if not op_name.startswith("aten::"):
128131
return False
129132
for prim_suffix in [
130-
"sym_size", "sym_numel", "sym_max", "sym_min", "sym_float"
133+
"sym_size",
134+
"sym_numel",
135+
"sym_max",
136+
"sym_min",
137+
"sym_float",
131138
]:
132139
if prim_suffix in op_name:
133140
return True
@@ -169,7 +176,6 @@ def et_operator_library(
169176
ops_schema_yaml_target = None,
170177
server_generated_yaml_target = None,
171178
**kwargs):
172-
173179
# Check if we should extract prim ops from the operator lists
174180
# Note that selective build for prim ops doesnt support model or ops_schema_yaml_target or server_generated_yaml_target
175181
# TODO: Add support for selective build for prim ops with model or ops_schema_yaml_target or server_generated_yaml_target
@@ -178,6 +184,7 @@ def et_operator_library(
178184
if should_extract_prim_ops:
179185
# Extract prim ops from ops and ops_dict
180186
prim_ops, remaining_ops, remaining_ops_dict = _extract_prim_ops_from_lists(ops, ops_dict)
187+
181188
# Use the remaining ops (with prim ops removed) for the main et_operator_library
182189
final_ops = remaining_ops
183190
final_ops_dict = remaining_ops_dict
@@ -189,6 +196,7 @@ def et_operator_library(
189196

190197
selected_operator_yaml_filename = "selected_operators.yaml"
191198
selected_prim_ops_filename = "selected_prim_ops.h"
199+
192200
# Generate the main operator library with the final ops
193201
# do a dummy copy if server_generated_yaml_target is set
194202
if server_generated_yaml_target:
@@ -231,6 +239,7 @@ def et_operator_library(
231239
"--prim_op_names=" + ",".join(prim_ops),
232240
"--output_dir=${OUT}",
233241
]
242+
234243
# Here we generate the selected_prim_ops.h and the selected_operators.yaml file
235244
# both with single genrule
236245
genrule_cmd = genrule_cmd + [" && "] + prim_ops_genrule_cmd
@@ -307,7 +316,6 @@ def _prepare_genrule_and_lib(
307316
if support_exceptions:
308317
genrule_cmd.append("--add-exception-boundary")
309318

310-
311319
# Sources for generated kernel registration lib
312320
sources = MANUAL_REGISTRATION_SOURCES if manual_registration else GENERATED_SOURCES
313321

@@ -371,7 +379,8 @@ def _prepare_custom_ops_genrule_and_lib(
371379
custom_ops_yaml_path = None,
372380
support_exceptions = True,
373381
deps = [],
374-
kernels = []):
382+
kernels = [],
383+
platforms = get_default_executorch_platforms()):
375384
"""Similar to _prepare_genrule_and_lib but for custom ops."""
376385
genrules = {}
377386
libs = {}
@@ -390,6 +399,7 @@ def _prepare_custom_ops_genrule_and_lib(
390399
"--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])),
391400
outs = {"selected_operators.yaml": ["selected_operators.yaml"]},
392401
default_outs = ["."],
402+
platforms = platforms,
393403
)
394404

395405
# genrule for generating operator kernel bindings
@@ -460,6 +470,7 @@ def exir_custom_ops_aot_lib(
460470
kernels = kernels,
461471
support_exceptions = support_exceptions,
462472
deps = deps,
473+
platforms = platforms,
463474
)
464475
for genrule in genrules:
465476
runtime.genrule(
@@ -468,6 +479,7 @@ def exir_custom_ops_aot_lib(
468479
cmd = genrules[genrule]["cmd"],
469480
outs = genrules[genrule]["outs"],
470481
default_outs = ["."],
482+
platforms = platforms,
471483
)
472484
for compiler_lib in libs:
473485
runtime.cxx_library(
@@ -538,29 +550,31 @@ def get_optimized_lib_deps():
538550
"//executorch/runtime/kernel:kernel_includes",
539551
] + get_vec_deps()
540552

541-
def build_portable_header_lib(name, oplist_header_name, feature = None):
553+
def build_portable_header_lib(name, oplist_header_name, feature = None, **kwargs):
542554
"""Build the portable headers into a header-only library.
543555
Ensures that includes work across portable and optimized libs.
544556
"""
545557
runtime.cxx_library(
546558
name = name,
547559
srcs = [],
548560
exported_headers = {
549-
"selected_op_variants.h":":{}[selected_op_variants]".format(oplist_header_name),
561+
"selected_op_variants.h": ":{}[selected_op_variants]".format(oplist_header_name),
550562
},
551563
exported_preprocessor_flags = ["-DEXECUTORCH_SELECTIVE_BUILD_DTYPE"],
552564
header_namespace = "",
553565
feature = feature,
566+
**kwargs
554567
)
555568

556569
def build_portable_lib(
557-
name,
558-
et_operator_lib_deps = [],
559-
oplist_header_name = None,
560-
portable_header_lib = None,
561-
feature = None,
562-
expose_operator_symbols = False,
563-
visibility = ["@EXECUTORCH_CLIENTS"]):
570+
name,
571+
et_operator_lib_deps = [],
572+
oplist_header_name = None,
573+
portable_header_lib = None,
574+
feature = None,
575+
expose_operator_symbols = False,
576+
visibility = ["@EXECUTORCH_CLIENTS"],
577+
platforms = get_default_executorch_platforms()):
564578
"""
565579
WARNING: Before using this, please consider using executorch_generated_lib instead. This
566580
function is only for special cases where you need to build a portable kernel library with
@@ -639,9 +653,10 @@ def build_portable_lib(
639653
# @lint-ignore BUCKLINT link_whole
640654
link_whole = True,
641655
feature = feature,
656+
platforms = platforms,
642657
)
643658

644-
def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature = None, expose_operator_symbols = False):
659+
def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature = None, expose_operator_symbols = False, platforms = get_default_executorch_platforms()):
645660
"""Build optimized lib from source. We build from source so that the generated header file,
646661
selected_op_variants.h, can be used to selectively build the lib for different dtypes.
647662
"""
@@ -661,7 +676,7 @@ def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature =
661676
# Currently fbcode links all dependent libraries through shared
662677
# library, and it blocks users like unit tests to use kernel
663678
# implementation directly. So we enable this for xplat only.
664-
compiler_flags = ["-Wno-missing-prototypes", "-Wno-pass-failed","-Wno-global-constructors","-Wno-shadow",]
679+
compiler_flags = ["-Wno-missing-prototypes", "-Wno-pass-failed", "-Wno-global-constructors", "-Wno-shadow"]
665680
if not expose_operator_symbols and is_xplat():
666681
# Removing '-fvisibility=hidden' exposes operator symbols.
667682
# This allows operators to be called outside of the kernel registry.
@@ -674,6 +689,7 @@ def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature =
674689
exported_preprocessor_flags = ["-DEXECUTORCH_SELECTIVE_BUILD_DTYPE"],
675690
deps = get_portable_lib_deps() + get_optimized_lib_deps() + [":" + portable_header_lib],
676691
compiler_flags = compiler_flags,
692+
platforms = platforms,
677693
preprocessor_flags = get_vec_preprocessor_flags(),
678694
# sleef needs to be added as a direct dependency of the operator target when building for Android,
679695
# or a linker error may occur. Not sure why this happens; it seems that fbandroid_platform_deps of
@@ -699,10 +715,9 @@ def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature =
699715
)
700716

701717
def selected_operators_genrule(
702-
name,
703-
deps,
704-
platforms = get_default_executorch_platforms(),
705-
):
718+
name,
719+
deps,
720+
platforms = get_default_executorch_platforms()):
706721
"""Generates selected_operators.yaml from the list of deps. We look into the trasitive closure of all the deps,
707722
and look for macros `et_operator_library`.
708723
@@ -725,10 +740,9 @@ def selected_operators_genrule(
725740
)
726741

727742
def selected_prim_operators_genrule(
728-
name,
729-
deps,
730-
platforms = get_default_executorch_platforms(),
731-
):
743+
name,
744+
deps,
745+
platforms = get_default_executorch_platforms()):
732746
"""Generates selected_prim_ops.h from the list of deps. We look into the transitive closure of all the deps,
733747
and look for targets with label `et_operator_library`.
734748
@@ -750,12 +764,11 @@ def selected_prim_operators_genrule(
750764
)
751765

752766
def dtype_header_genrule(
753-
name,
754-
visibility,
755-
deps = [],
756-
selected_operators_genrule_name = None,
757-
platforms = get_default_executorch_platforms(),
758-
):
767+
name,
768+
visibility,
769+
deps = [],
770+
selected_operators_genrule_name = None,
771+
platforms = get_default_executorch_platforms()):
759772
"""Generate selected_op_variants.h from selected_operators.yaml.
760773
761774
Given a `selected_operators.yaml` (passed in as selected_operators_genrule_name), we should be able to determine
@@ -921,15 +934,14 @@ def executorch_generated_lib(
921934
index = index + 1
922935
portable = name + "_check_portable_" + dep.split(":")[1] + str(index)
923936
message = "Dtype selective build requires that the portable library is not passed into `deps`. This will cause duplicate symbol errors in the build. Please remove it from `deps` and place it into `kernel_deps`"
924-
check_recursive_dependencies(portable, dep, "//executorch/kernels/portable:operators", message)
937+
check_recursive_dependencies(portable, dep, "//executorch/kernels/portable:operators", message, platforms = platforms)
925938
if ("//executorch/kernels/optimized:optimized_operators" in kernel_deps):
926939
index = 0
927940
for dep in deps:
928941
index = index + 1
929942
optimized = name + "_check_optimized_" + dep.split(":")[1] + str(index)
930943
message = "Dtype selective build requires that the optimized library is not passed into `deps`. This will cause duplicate symbol errors in the build. Please remove it from `deps` and place it into `kernel_deps`"
931-
check_recursive_dependencies(optimized, dep, "//executorch/kernels/optimized:optimized_operators", message)
932-
944+
check_recursive_dependencies(optimized, dep, "//executorch/kernels/optimized:optimized_operators", message, platforms = platforms)
933945

934946
aten_suffix = "_aten" if aten_mode else ""
935947

@@ -995,15 +1007,15 @@ def executorch_generated_lib(
9951007
if dtype_selective_build:
9961008
# Build portable headers lib. Used for portable and optimized kernel libraries.
9971009
portable_header_lib = name + "_portable_header_lib"
998-
build_portable_header_lib(portable_header_lib, oplist_header_name, feature)
1010+
build_portable_header_lib(portable_header_lib, oplist_header_name, feature, platforms = platforms)
9991011

10001012
if "//executorch/kernels/portable:operators" in kernel_deps:
10011013
# Remove portable from kernel_deps as we're building it from source.
10021014
kernel_deps.remove("//executorch/kernels/portable:operators")
10031015

10041016
# Build portable lib.
10051017
portable_lib_name = name + "_portable_lib"
1006-
build_portable_lib(name = portable_lib_name, portable_header_lib = portable_header_lib, feature = feature, expose_operator_symbols = expose_operator_symbols)
1018+
build_portable_lib(name = portable_lib_name, portable_header_lib = portable_header_lib, feature = feature, expose_operator_symbols = expose_operator_symbols, platforms = platforms)
10071019
kernel_deps.append(":{}".format(portable_lib_name))
10081020

10091021
if "//executorch/kernels/optimized:optimized_operators" in kernel_deps:
@@ -1012,7 +1024,7 @@ def executorch_generated_lib(
10121024

10131025
# Build optimized lib.
10141026
optimized_lib_name = name + "_optimized_lib"
1015-
build_optimized_lib(optimized_lib_name, oplist_header_name, portable_header_lib, feature, expose_operator_symbols)
1027+
build_optimized_lib(optimized_lib_name, oplist_header_name, portable_header_lib, feature, expose_operator_symbols, platforms = platforms)
10161028
kernel_deps.append(":{}".format(optimized_lib_name))
10171029

10181030
# Exports headers that declare the function signatures of the C++ functions
@@ -1111,10 +1123,9 @@ def executorch_generated_lib(
11111123
#
11121124
# If build successfully, all of the `selected_operators.yaml` will be merged into 1 `selected_operators.yaml` for debugging purpose.
11131125
def executorch_ops_check(
1114-
name,
1115-
deps,
1116-
**kwargs,
1117-
):
1126+
name,
1127+
deps,
1128+
**kwargs):
11181129
runtime.genrule(
11191130
name = name,
11201131
macros_only = False,
@@ -1128,16 +1139,15 @@ def executorch_ops_check(
11281139
platforms = kwargs.pop("platforms", get_default_executorch_platforms()),
11291140
outs = {"selected_operators.yaml": ["selected_operators.yaml"]},
11301141
default_outs = ["."],
1131-
**kwargs,
1142+
**kwargs
11321143
)
11331144

11341145
def check_recursive_dependencies(
1135-
name,
1136-
parent,
1137-
child,
1138-
message = "",
1139-
**kwargs,
1140-
):
1146+
name,
1147+
parent,
1148+
child,
1149+
message = "",
1150+
**kwargs):
11411151
"""
11421152
Checks if child is a transitive dependency of parent and fails if it is.
11431153
The query runs the equivalent of `buck2 uquery "allpaths(parent, child)".

shim_et/xplat/executorch/kernels/prim_ops/selective_build.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def prim_ops_registry_selective(name, selected_prim_ops_header_target, aten_suff
2828
header_name: [header_name],
2929
"selected_prim_ops.h": ["selected_prim_ops.h"]
3030
},
31+
platforms = kwargs.get("platforms", "CXX"),
3132
default_outs = ["."],
3233
)
3334
runtime.cxx_library(

0 commit comments

Comments
 (0)