diff --git a/shim_et/xplat/executorch/codegen/codegen.bzl b/shim_et/xplat/executorch/codegen/codegen.bzl index 3546b64cdb6..0002884b2a4 100644 --- a/shim_et/xplat/executorch/codegen/codegen.bzl +++ b/shim_et/xplat/executorch/codegen/codegen.bzl @@ -1,12 +1,12 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_default_executorch_platforms", "is_xplat", "runtime", "struct_to_json") load("@fbsource//xplat/executorch/build:selects.bzl", "selects") -load("@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", "portable_source_list") -load("@fbsource//xplat/executorch/kernels/optimized:op_registration_util.bzl", "optimized_source_list") load( "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl", "get_vec_deps", "get_vec_preprocessor_flags", ) +load("@fbsource//xplat/executorch/kernels/optimized:op_registration_util.bzl", "optimized_source_list") +load("@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl", "portable_source_list") load("@fbsource//xplat/executorch/kernels/prim_ops:selective_build.bzl", "prim_ops_registry_selective") # Headers that declare the function signatures of the C++ functions that @@ -96,15 +96,17 @@ def _get_prim_ops_registry_target(name, deps, aten_suffix, platforms): Returns: String: Target name for the appropriate prim ops registry """ + # If selective build targets are specified, create a selective prim ops registry # Create a selective prim ops registry using the existing function selective_prim_ops_registry_name = name + "_selected_prim_ops_registry" combined_prim_ops_header_target_name = name + "_combined_prim_ops_header" selected_prim_operators_genrule(combined_prim_ops_header_target_name, deps, platforms) + # Use the existing prim_ops_registry_selective function prim_ops_registry_selective( name = selective_prim_ops_registry_name, - selected_prim_ops_header_target = ":"+combined_prim_ops_header_target_name, + selected_prim_ops_header_target = ":" + combined_prim_ops_header_target_name, aten_suffix = aten_suffix, platforms = platforms, ) @@ -123,11 +125,16 @@ def _extract_prim_ops_from_lists(ops, ops_dict): Returns: Tuple of (prim_ops, remaining_ops, remaining_ops_dict) """ + def _is_aten_prim_op(op_name): if not op_name.startswith("aten::"): return False for prim_suffix in [ - "sym_size", "sym_numel", "sym_max", "sym_min", "sym_float" + "sym_size", + "sym_numel", + "sym_max", + "sym_min", + "sym_float", ]: if prim_suffix in op_name: return True @@ -169,7 +176,6 @@ def et_operator_library( ops_schema_yaml_target = None, server_generated_yaml_target = None, **kwargs): - # Check if we should extract prim ops from the operator lists # Note that selective build for prim ops doesnt support model or ops_schema_yaml_target or server_generated_yaml_target # TODO: Add support for selective build for prim ops with model or ops_schema_yaml_target or server_generated_yaml_target @@ -178,6 +184,7 @@ def et_operator_library( if should_extract_prim_ops: # Extract prim ops from ops and ops_dict prim_ops, remaining_ops, remaining_ops_dict = _extract_prim_ops_from_lists(ops, ops_dict) + # Use the remaining ops (with prim ops removed) for the main et_operator_library final_ops = remaining_ops final_ops_dict = remaining_ops_dict @@ -189,6 +196,7 @@ def et_operator_library( selected_operator_yaml_filename = "selected_operators.yaml" selected_prim_ops_filename = "selected_prim_ops.h" + # Generate the main operator library with the final ops # do a dummy copy if server_generated_yaml_target is set if server_generated_yaml_target: @@ -231,6 +239,7 @@ def et_operator_library( "--prim_op_names=" + ",".join(prim_ops), "--output_dir=${OUT}", ] + # Here we generate the selected_prim_ops.h and the selected_operators.yaml file # both with single genrule genrule_cmd = genrule_cmd + [" && "] + prim_ops_genrule_cmd @@ -307,7 +316,6 @@ def _prepare_genrule_and_lib( if support_exceptions: genrule_cmd.append("--add-exception-boundary") - # Sources for generated kernel registration lib sources = MANUAL_REGISTRATION_SOURCES if manual_registration else GENERATED_SOURCES @@ -371,7 +379,8 @@ def _prepare_custom_ops_genrule_and_lib( custom_ops_yaml_path = None, support_exceptions = True, deps = [], - kernels = []): + kernels = [], + platforms = get_default_executorch_platforms()): """Similar to _prepare_genrule_and_lib but for custom ops.""" genrules = {} libs = {} @@ -390,6 +399,7 @@ def _prepare_custom_ops_genrule_and_lib( "--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])), outs = {"selected_operators.yaml": ["selected_operators.yaml"]}, default_outs = ["."], + platforms = platforms, ) # genrule for generating operator kernel bindings @@ -460,6 +470,7 @@ def exir_custom_ops_aot_lib( kernels = kernels, support_exceptions = support_exceptions, deps = deps, + platforms = platforms, ) for genrule in genrules: runtime.genrule( @@ -468,6 +479,7 @@ def exir_custom_ops_aot_lib( cmd = genrules[genrule]["cmd"], outs = genrules[genrule]["outs"], default_outs = ["."], + platforms = platforms, ) for compiler_lib in libs: runtime.cxx_library( @@ -538,7 +550,7 @@ def get_optimized_lib_deps(): "//executorch/runtime/kernel:kernel_includes", ] + get_vec_deps() -def build_portable_header_lib(name, oplist_header_name, feature = None): +def build_portable_header_lib(name, oplist_header_name, feature = None, **kwargs): """Build the portable headers into a header-only library. Ensures that includes work across portable and optimized libs. """ @@ -546,21 +558,23 @@ def build_portable_header_lib(name, oplist_header_name, feature = None): name = name, srcs = [], exported_headers = { - "selected_op_variants.h":":{}[selected_op_variants]".format(oplist_header_name), + "selected_op_variants.h": ":{}[selected_op_variants]".format(oplist_header_name), }, exported_preprocessor_flags = ["-DEXECUTORCH_SELECTIVE_BUILD_DTYPE"], header_namespace = "", feature = feature, + **kwargs ) def build_portable_lib( - name, - et_operator_lib_deps = [], - oplist_header_name = None, - portable_header_lib = None, - feature = None, - expose_operator_symbols = False, - visibility = ["@EXECUTORCH_CLIENTS"]): + name, + et_operator_lib_deps = [], + oplist_header_name = None, + portable_header_lib = None, + feature = None, + expose_operator_symbols = False, + visibility = ["@EXECUTORCH_CLIENTS"], + platforms = get_default_executorch_platforms()): """ WARNING: Before using this, please consider using executorch_generated_lib instead. This function is only for special cases where you need to build a portable kernel library with @@ -639,9 +653,10 @@ def build_portable_lib( # @lint-ignore BUCKLINT link_whole link_whole = True, feature = feature, + platforms = platforms, ) -def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature = None, expose_operator_symbols = False): +def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature = None, expose_operator_symbols = False, platforms = get_default_executorch_platforms()): """Build optimized lib from source. We build from source so that the generated header file, selected_op_variants.h, can be used to selectively build the lib for different dtypes. """ @@ -661,7 +676,7 @@ def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature = # Currently fbcode links all dependent libraries through shared # library, and it blocks users like unit tests to use kernel # implementation directly. So we enable this for xplat only. - compiler_flags = ["-Wno-missing-prototypes", "-Wno-pass-failed","-Wno-global-constructors","-Wno-shadow",] + compiler_flags = ["-Wno-missing-prototypes", "-Wno-pass-failed", "-Wno-global-constructors", "-Wno-shadow"] if not expose_operator_symbols and is_xplat(): # Removing '-fvisibility=hidden' exposes operator symbols. # This allows operators to be called outside of the kernel registry. @@ -674,6 +689,7 @@ def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature = exported_preprocessor_flags = ["-DEXECUTORCH_SELECTIVE_BUILD_DTYPE"], deps = get_portable_lib_deps() + get_optimized_lib_deps() + [":" + portable_header_lib], compiler_flags = compiler_flags, + platforms = platforms, preprocessor_flags = get_vec_preprocessor_flags(), # sleef needs to be added as a direct dependency of the operator target when building for Android, # or a linker error may occur. Not sure why this happens; it seems that fbandroid_platform_deps of @@ -699,10 +715,9 @@ def build_optimized_lib(name, oplist_header_name, portable_header_lib, feature = ) def selected_operators_genrule( - name, - deps, - platforms = get_default_executorch_platforms(), -): + name, + deps, + platforms = get_default_executorch_platforms()): """Generates selected_operators.yaml from the list of deps. We look into the trasitive closure of all the deps, and look for macros `et_operator_library`. @@ -725,10 +740,9 @@ def selected_operators_genrule( ) def selected_prim_operators_genrule( - name, - deps, - platforms = get_default_executorch_platforms(), -): + name, + deps, + platforms = get_default_executorch_platforms()): """Generates selected_prim_ops.h from the list of deps. We look into the transitive closure of all the deps, and look for targets with label `et_operator_library`. @@ -750,12 +764,11 @@ def selected_prim_operators_genrule( ) def dtype_header_genrule( - name, - visibility, - deps = [], - selected_operators_genrule_name = None, - platforms = get_default_executorch_platforms(), -): + name, + visibility, + deps = [], + selected_operators_genrule_name = None, + platforms = get_default_executorch_platforms()): """Generate selected_op_variants.h from selected_operators.yaml. Given a `selected_operators.yaml` (passed in as selected_operators_genrule_name), we should be able to determine @@ -921,15 +934,14 @@ def executorch_generated_lib( index = index + 1 portable = name + "_check_portable_" + dep.split(":")[1] + str(index) message = "Dtype selective build requires that the portable library is not passed into `deps`. This will cause duplicate symbol errors in the build. Please remove it from `deps` and place it into `kernel_deps`" - check_recursive_dependencies(portable, dep, "//executorch/kernels/portable:operators", message) + check_recursive_dependencies(portable, dep, "//executorch/kernels/portable:operators", message, platforms = platforms) if ("//executorch/kernels/optimized:optimized_operators" in kernel_deps): index = 0 for dep in deps: index = index + 1 optimized = name + "_check_optimized_" + dep.split(":")[1] + str(index) message = "Dtype selective build requires that the optimized library is not passed into `deps`. This will cause duplicate symbol errors in the build. Please remove it from `deps` and place it into `kernel_deps`" - check_recursive_dependencies(optimized, dep, "//executorch/kernels/optimized:optimized_operators", message) - + check_recursive_dependencies(optimized, dep, "//executorch/kernels/optimized:optimized_operators", message, platforms = platforms) aten_suffix = "_aten" if aten_mode else "" @@ -995,7 +1007,7 @@ def executorch_generated_lib( if dtype_selective_build: # Build portable headers lib. Used for portable and optimized kernel libraries. portable_header_lib = name + "_portable_header_lib" - build_portable_header_lib(portable_header_lib, oplist_header_name, feature) + build_portable_header_lib(portable_header_lib, oplist_header_name, feature, platforms = platforms) if "//executorch/kernels/portable:operators" in kernel_deps: # Remove portable from kernel_deps as we're building it from source. @@ -1003,7 +1015,7 @@ def executorch_generated_lib( # Build portable lib. portable_lib_name = name + "_portable_lib" - build_portable_lib(name = portable_lib_name, portable_header_lib = portable_header_lib, feature = feature, expose_operator_symbols = expose_operator_symbols) + build_portable_lib(name = portable_lib_name, portable_header_lib = portable_header_lib, feature = feature, expose_operator_symbols = expose_operator_symbols, platforms = platforms) kernel_deps.append(":{}".format(portable_lib_name)) if "//executorch/kernels/optimized:optimized_operators" in kernel_deps: @@ -1012,7 +1024,7 @@ def executorch_generated_lib( # Build optimized lib. optimized_lib_name = name + "_optimized_lib" - build_optimized_lib(optimized_lib_name, oplist_header_name, portable_header_lib, feature, expose_operator_symbols) + build_optimized_lib(optimized_lib_name, oplist_header_name, portable_header_lib, feature, expose_operator_symbols, platforms = platforms) kernel_deps.append(":{}".format(optimized_lib_name)) # Exports headers that declare the function signatures of the C++ functions @@ -1111,10 +1123,9 @@ def executorch_generated_lib( # # If build successfully, all of the `selected_operators.yaml` will be merged into 1 `selected_operators.yaml` for debugging purpose. def executorch_ops_check( - name, - deps, - **kwargs, -): + name, + deps, + **kwargs): runtime.genrule( name = name, macros_only = False, @@ -1128,16 +1139,15 @@ def executorch_ops_check( platforms = kwargs.pop("platforms", get_default_executorch_platforms()), outs = {"selected_operators.yaml": ["selected_operators.yaml"]}, default_outs = ["."], - **kwargs, + **kwargs ) def check_recursive_dependencies( - name, - parent, - child, - message = "", - **kwargs, -): + name, + parent, + child, + message = "", + **kwargs): """ Checks if child is a transitive dependency of parent and fails if it is. The query runs the equivalent of `buck2 uquery "allpaths(parent, child)". diff --git a/shim_et/xplat/executorch/kernels/prim_ops/selective_build.bzl b/shim_et/xplat/executorch/kernels/prim_ops/selective_build.bzl index a5c89147801..73421f031ec 100644 --- a/shim_et/xplat/executorch/kernels/prim_ops/selective_build.bzl +++ b/shim_et/xplat/executorch/kernels/prim_ops/selective_build.bzl @@ -28,6 +28,7 @@ def prim_ops_registry_selective(name, selected_prim_ops_header_target, aten_suff header_name: [header_name], "selected_prim_ops.h": ["selected_prim_ops.h"] }, + platforms = kwargs.get("platforms", "CXX"), default_outs = ["."], ) runtime.cxx_library(