Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 59 additions & 49 deletions shim_et/xplat/executorch/codegen/codegen.bzl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_default_executorch_platforms", "is_xplat", "runtime", "struct_to_json")
load("@fbsource//xplat/executorch/build:selects.bzl", "selects")
load("@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", "portable_source_list")
load("@fbsource//xplat/executorch/kernels/optimized:op_registration_util.bzl", "optimized_source_list")
load(
"@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
"get_vec_deps",
"get_vec_preprocessor_flags",
)
load("@fbsource//xplat/executorch/kernels/optimized:op_registration_util.bzl", "optimized_source_list")
load("@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", "portable_source_list")
load("@fbsource//xplat/executorch/kernels/prim_ops:selective_build.bzl", "prim_ops_registry_selective")

# Headers that declare the function signatures of the C++ functions that
Expand Down Expand Up @@ -96,15 +96,17 @@ def _get_prim_ops_registry_target(name, deps, aten_suffix, platforms):
Returns:
String: Target name for the appropriate prim ops registry
"""

# If selective build targets are specified, create a selective prim ops registry
# Create a selective prim ops registry using the existing function
selective_prim_ops_registry_name = name + "_selected_prim_ops_registry"
combined_prim_ops_header_target_name = name + "_combined_prim_ops_header"
selected_prim_operators_genrule(combined_prim_ops_header_target_name, deps, platforms)

# Use the existing prim_ops_registry_selective function
prim_ops_registry_selective(
name = selective_prim_ops_registry_name,
selected_prim_ops_header_target = ":"+combined_prim_ops_header_target_name,
selected_prim_ops_header_target = ":" + combined_prim_ops_header_target_name,
aten_suffix = aten_suffix,
platforms = platforms,
)
Expand All @@ -123,11 +125,16 @@ def _extract_prim_ops_from_lists(ops, ops_dict):
Returns:
Tuple of (prim_ops, remaining_ops, remaining_ops_dict)
"""

def _is_aten_prim_op(op_name):
if not op_name.startswith("aten::"):
return False
for prim_suffix in [
"sym_size", "sym_numel", "sym_max", "sym_min", "sym_float"
"sym_size",
"sym_numel",
"sym_max",
"sym_min",
"sym_float",
]:
if prim_suffix in op_name:
return True
Expand Down Expand Up @@ -169,7 +176,6 @@ def et_operator_library(
ops_schema_yaml_target = None,
server_generated_yaml_target = None,
**kwargs):

# Check if we should extract prim ops from the operator lists
# Note that selective build for prim ops doesn't support model or ops_schema_yaml_target or server_generated_yaml_target
# TODO: Add support for selective build for prim ops with model or ops_schema_yaml_target or server_generated_yaml_target
Expand All @@ -178,6 +184,7 @@ def et_operator_library(
if should_extract_prim_ops:
# Extract prim ops from ops and ops_dict
prim_ops, remaining_ops, remaining_ops_dict = _extract_prim_ops_from_lists(ops, ops_dict)

# Use the remaining ops (with prim ops removed) for the main et_operator_library
final_ops = remaining_ops
final_ops_dict = remaining_ops_dict
Expand All @@ -189,6 +196,7 @@ def et_operator_library(

selected_operator_yaml_filename = "selected_operators.yaml"
selected_prim_ops_filename = "selected_prim_ops.h"

# Generate the main operator library with the final ops
# do a dummy copy if server_generated_yaml_target is set
if server_generated_yaml_target:
Expand Down Expand Up @@ -231,6 +239,7 @@ def et_operator_library(
"--prim_op_names=" + ",".join(prim_ops),
"--output_dir=${OUT}",
]

# Here we generate the selected_prim_ops.h and the selected_operators.yaml file
# both with single genrule
genrule_cmd = genrule_cmd + [" && "] + prim_ops_genrule_cmd
Expand Down Expand Up @@ -307,7 +316,6 @@ def _prepare_genrule_and_lib(
if support_exceptions:
genrule_cmd.append("--add-exception-boundary")


# Sources for generated kernel registration lib
sources = MANUAL_REGISTRATION_SOURCES if manual_registration else GENERATED_SOURCES

Expand Down Expand Up @@ -371,7 +379,8 @@ def _prepare_custom_ops_genrule_and_lib(
custom_ops_yaml_path = None,
support_exceptions = True,
deps = [],
kernels = []):
kernels = [],
platforms = get_default_executorch_platforms()):
"""Similar to _prepare_genrule_and_lib but for custom ops."""
genrules = {}
libs = {}
Expand All @@ -390,6 +399,7 @@ def _prepare_custom_ops_genrule_and_lib(
"--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])),
outs = {"selected_operators.yaml": ["selected_operators.yaml"]},
default_outs = ["."],
platforms = platforms,
)

# genrule for generating operator kernel bindings
Expand Down Expand Up @@ -460,6 +470,7 @@ def exir_custom_ops_aot_lib(
kernels = kernels,
support_exceptions = support_exceptions,
deps = deps,
platforms = platforms,
)
for genrule in genrules:
runtime.genrule(
Expand All @@ -468,6 +479,7 @@ def exir_custom_ops_aot_lib(
cmd = genrules[genrule]["cmd"],
outs = genrules[genrule]["outs"],
default_outs = ["."],
platforms = platforms,
)
for compiler_lib in libs:
runtime.cxx_library(
Expand Down Expand Up @@ -538,29 +550,31 @@ def get_optimized_lib_deps():
"//executorch/runtime/kernel:kernel_includes",
] + get_vec_deps()

def build_portable_header_lib(name, oplist_header_name, feature = None):
def build_portable_header_lib(name, oplist_header_name, feature = None, **kwargs):
"""Build the portable headers into a header-only library.
Ensures that includes work across portable and optimized libs.
"""
runtime.cxx_library(
name = name,
srcs = [],
exported_headers = {
"selected_op_variants.h":":{}[selected_op_variants]".format(oplist_header_name),
"selected_op_variants.h": ":{}[selected_op_variants]".format(oplist_header_name),
},
exported_preprocessor_flags = ["-DEXECUTORCH_SELECTIVE_BUILD_DTYPE"],
header_namespace = "",
feature = feature,
**kwargs
)

def build_portable_lib(
name,
et_operator_lib_deps = [],
oplist_header_name = None,
portable_header_lib = None,
feature = None,
expose_operator_symbols = False,
visibility = ["@EXECUTORCH_CLIENTS"]):
name,
et_operator_lib_deps = [],
oplist_header_name = None,
portable_header_lib = None,
feature = None,
expose_operator_symbols = False,
visibility = ["@EXECUTORCH_CLIENTS"],
platforms = get_default_executorch_platforms()):
"""
WARNING: Before using this, please consider using executorch_generated_lib instead. This
function is only for special cases where you need to build a portable kernel library with
Expand Down Expand Up @@ -639,9 +653,10 @@ def build_portable_lib(
# @lint-ignore BUCKLINT link_whole
link_whole = True,
feature = feature,
platforms = platforms,
)

def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature = None, expose_operator_symbols = False):
def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature = None, expose_operator_symbols = False, platforms = get_default_executorch_platforms()):
"""Build optimized lib from source. We build from source so that the generated header file,
selected_op_variants.h, can be used to selectively build the lib for different dtypes.
"""
Expand All @@ -661,7 +676,7 @@ def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature =
# Currently fbcode links all dependent libraries through shared
# library, and it blocks users like unit tests to use kernel
# implementation directly. So we enable this for xplat only.
compiler_flags = ["-Wno-missing-prototypes", "-Wno-pass-failed","-Wno-global-constructors","-Wno-shadow",]
compiler_flags = ["-Wno-missing-prototypes", "-Wno-pass-failed", "-Wno-global-constructors", "-Wno-shadow"]
if not expose_operator_symbols and is_xplat():
# Removing '-fvisibility=hidden' exposes operator symbols.
# This allows operators to be called outside of the kernel registry.
Expand All @@ -674,6 +689,7 @@ def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature =
exported_preprocessor_flags = ["-DEXECUTORCH_SELECTIVE_BUILD_DTYPE"],
deps = get_portable_lib_deps() + get_optimized_lib_deps() + [":" + portable_header_lib],
compiler_flags = compiler_flags,
platforms = platforms,
preprocessor_flags = get_vec_preprocessor_flags(),
# sleef needs to be added as a direct dependency of the operator target when building for Android,
# or a linker error may occur. Not sure why this happens; it seems that fbandroid_platform_deps of
Expand All @@ -699,10 +715,9 @@ def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature =
)

def selected_operators_genrule(
name,
deps,
platforms = get_default_executorch_platforms(),
):
name,
deps,
platforms = get_default_executorch_platforms()):
"""Generates selected_operators.yaml from the list of deps. We look into the trasitive closure of all the deps,
and look for macros `et_operator_library`.

Expand All @@ -725,10 +740,9 @@ def selected_operators_genrule(
)

def selected_prim_operators_genrule(
name,
deps,
platforms = get_default_executorch_platforms(),
):
name,
deps,
platforms = get_default_executorch_platforms()):
"""Generates selected_prim_ops.h from the list of deps. We look into the transitive closure of all the deps,
and look for targets with label `et_operator_library`.

Expand All @@ -750,12 +764,11 @@ def selected_prim_operators_genrule(
)

def dtype_header_genrule(
name,
visibility,
deps = [],
selected_operators_genrule_name = None,
platforms = get_default_executorch_platforms(),
):
name,
visibility,
deps = [],
selected_operators_genrule_name = None,
platforms = get_default_executorch_platforms()):
"""Generate selected_op_variants.h from selected_operators.yaml.

Given a `selected_operators.yaml` (passed in as selected_operators_genrule_name), we should be able to determine
Expand Down Expand Up @@ -921,15 +934,14 @@ def executorch_generated_lib(
index = index + 1
portable = name + "_check_portable_" + dep.split(":")[1] + str(index)
message = "Dtype selective build requires that the portable library is not passed into `deps`. This will cause duplicate symbol errors in the build. Please remove it from `deps` and place it into `kernel_deps`"
check_recursive_dependencies(portable, dep, "//executorch/kernels/portable:operators", message)
check_recursive_dependencies(portable, dep, "//executorch/kernels/portable:operators", message, platforms = platforms)
if ("//executorch/kernels/optimized:optimized_operators" in kernel_deps):
index = 0
for dep in deps:
index = index + 1
optimized = name + "_check_optimized_" + dep.split(":")[1] + str(index)
message = "Dtype selective build requires that the optimized library is not passed into `deps`. This will cause duplicate symbol errors in the build. Please remove it from `deps` and place it into `kernel_deps`"
check_recursive_dependencies(optimized, dep, "//executorch/kernels/optimized:optimized_operators", message)

check_recursive_dependencies(optimized, dep, "//executorch/kernels/optimized:optimized_operators", message, platforms = platforms)

aten_suffix = "_aten" if aten_mode else ""

Expand Down Expand Up @@ -995,15 +1007,15 @@ def executorch_generated_lib(
if dtype_selective_build:
# Build portable headers lib. Used for portable and optimized kernel libraries.
portable_header_lib = name + "_portable_header_lib"
build_portable_header_lib(portable_header_lib, oplist_header_name, feature)
build_portable_header_lib(portable_header_lib, oplist_header_name, feature, platforms = platforms)

if "//executorch/kernels/portable:operators" in kernel_deps:
# Remove portable from kernel_deps as we're building it from source.
kernel_deps.remove("//executorch/kernels/portable:operators")

# Build portable lib.
portable_lib_name = name + "_portable_lib"
build_portable_lib(name = portable_lib_name, portable_header_lib = portable_header_lib, feature = feature, expose_operator_symbols = expose_operator_symbols)
build_portable_lib(name = portable_lib_name, portable_header_lib = portable_header_lib, feature = feature, expose_operator_symbols = expose_operator_symbols, platforms = platforms)
kernel_deps.append(":{}".format(portable_lib_name))

if "//executorch/kernels/optimized:optimized_operators" in kernel_deps:
Expand All @@ -1012,7 +1024,7 @@ def executorch_generated_lib(

# Build optimized lib.
optimized_lib_name = name + "_optimized_lib"
build_optimized_lib(optimized_lib_name, oplist_header_name, portable_header_lib, feature, expose_operator_symbols)
build_optimized_lib(optimized_lib_name, oplist_header_name, portable_header_lib, feature, expose_operator_symbols, platforms = platforms)
kernel_deps.append(":{}".format(optimized_lib_name))

# Exports headers that declare the function signatures of the C++ functions
Expand Down Expand Up @@ -1111,10 +1123,9 @@ def executorch_generated_lib(
#
# If build successfully, all of the `selected_operators.yaml` will be merged into 1 `selected_operators.yaml` for debugging purpose.
def executorch_ops_check(
name,
deps,
**kwargs,
):
name,
deps,
**kwargs):
runtime.genrule(
name = name,
macros_only = False,
Expand All @@ -1128,16 +1139,15 @@ def executorch_ops_check(
platforms = kwargs.pop("platforms", get_default_executorch_platforms()),
outs = {"selected_operators.yaml": ["selected_operators.yaml"]},
default_outs = ["."],
**kwargs,
**kwargs
)

def check_recursive_dependencies(
name,
parent,
child,
message = "",
**kwargs,
):
name,
parent,
child,
message = "",
**kwargs):
"""
Checks if child is a transitive dependency of parent and fails if it is.
The query runs the equivalent of `buck2 uquery "allpaths(parent, child)".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def prim_ops_registry_selective(name, selected_prim_ops_header_target, aten_suff
header_name: [header_name],
"selected_prim_ops.h": ["selected_prim_ops.h"]
},
platforms = kwargs.get("platforms", "CXX"),
default_outs = ["."],
)
runtime.cxx_library(
Expand Down
Loading