diff --git a/CMakeLists.txt b/CMakeLists.txt
index c7ed7b1fcb1..3a3b1f7bfe0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -554,6 +554,10 @@ if(EXECUTORCH_BUILD_PTHREADPOOL AND EXECUTORCH_BUILD_CPUINFO)
 endif()
 
 if(EXECUTORCH_BUILD_PYBIND)
+
+  # Add codegen tools subdirectory for selective_build pybind module
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/codegen/tools)
+
   if(NOT EXECUTORCH_BUILD_EXTENSION_DATA_LOADER)
     add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/data_loader)
   endif()
diff --git a/codegen/tools/CMakeLists.txt b/codegen/tools/CMakeLists.txt
new file mode 100644
index 00000000000..6690418dd6f
--- /dev/null
+++ b/codegen/tools/CMakeLists.txt
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Check if pybind11 is available
+if(NOT COMMAND pybind11_add_module)
+  message(FATAL_ERROR "pybind11 is required to build selective_build")
+endif()
+
+# Create the selective_build pybind11 module
+pybind11_add_module(selective_build SHARED selective_build.cpp)
+
+# Set the output name to match the module name
+set_target_properties(selective_build PROPERTIES OUTPUT_NAME "selective_build")
+
+# Set the module name for the pybind11 module
+target_compile_definitions(
+  selective_build PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=selective_build
+)
+
+# Include directories
+target_include_directories(
+  selective_build PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../..
+)
+
+# Compile options
+target_compile_options(
+  selective_build PUBLIC
+  -Wno-deprecated-declarations
+  -fPIC
+  -frtti
+  -fexceptions
+)
+
+# Link against required libraries
+target_link_libraries(
+  selective_build PRIVATE
+  executorch_core
+  program_schema
+)
+
+# Install the module
+install(TARGETS selective_build
+        LIBRARY DESTINATION executorch/codegen/tools
+)
diff --git a/codegen/tools/gen_oplist.py b/codegen/tools/gen_oplist.py
index 3d26797fb24..cca5bf1b1d2 100644
--- a/codegen/tools/gen_oplist.py
+++ b/codegen/tools/gen_oplist.py
@@ -20,7 +20,6 @@
 # We can use relative import instead.
 from ..parse import strip_et_fields
 
-
 from torchgen.gen import LineLoader, parse_native_yaml_struct
 from torchgen.selective_build.operator import SelectiveBuildOperator
 from torchgen.selective_build.selector import merge_et_kernel_metadata
@@ -102,7 +101,6 @@ def _get_operators(model_file: str) -> List[str]:
 
 
 def _get_kernel_metadata_for_model(model_file: str) -> Dict[str, List[str]]:
-
     from executorch.codegen.tools.selective_build import (  # type: ignore[import-not-found]
         _get_io_metadata_for_program_operators,
         _get_program_from_buffer,
diff --git a/codegen/tools/selective_build.cpp b/codegen/tools/selective_build.cpp
new file mode 100644
index 00000000000..d33ff12ec9f
--- /dev/null
+++ b/codegen/tools/selective_build.cpp
@@ -0,0 +1,279 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include <executorch/runtime/platform/assert.h>
+#include <executorch/schema/program_generated.h>
+
+namespace py = pybind11;
+
+namespace torch {
+namespace executor {
+
+namespace {
+
+// Metadata for kernel call io variables.
+// dtype and dim_order are meaningful only if the variable is a Tensor.
+struct IOMetaData {
+  int kernel_type;
+  // Stays -1 for non-tensor variables so the Python binding never exposes an
+  // uninitialized value.
+  int dtype = -1;
+  std::vector<unsigned int> dim_order;
+
+  // Create tensor metadata. It records tensor's dtype and dim order.
+  explicit IOMetaData(const executorch_flatbuffer::Tensor* t)
+      : kernel_type(
+            static_cast<int>(executorch_flatbuffer::KernelTypes::Tensor)),
+        dtype(static_cast<int>(t->scalar_type())) {
+    for (size_t i = 0; i < t->dim_order()->size(); i++) {
+      dim_order.push_back(static_cast<unsigned int>(t->dim_order()->Get(i)));
+    }
+  }
+
+  // Create metadata for non-tensor variable.
+  explicit IOMetaData(executorch_flatbuffer::KernelTypes type)
+      : kernel_type(static_cast<int>(type)) {
+    ET_CHECK(
+        type != executorch_flatbuffer::KernelTypes::Tensor &&
+        type != executorch_flatbuffer::KernelTypes::TensorList &&
+        type != executorch_flatbuffer::KernelTypes::OptionalTensorList);
+  }
+};
+
+// Orders kernel IO signatures so they can be stored in a std::set: first by
+// length, then element-wise by kernel_type and, for tensors only, by dtype
+// and dim_order.
+struct KernelIOMetaDataComparison {
+  bool operator()(
+      const std::vector<IOMetaData>& lhs,
+      const std::vector<IOMetaData>& rhs) const {
+    if (lhs.size() != rhs.size()) {
+      return lhs.size() < rhs.size();
+    }
+    for (size_t i = 0; i < lhs.size(); i++) {
+      if (lhs[i].kernel_type != rhs[i].kernel_type) {
+        return lhs[i].kernel_type < rhs[i].kernel_type;
+      }
+      if (lhs[i].kernel_type !=
+          static_cast<int>(executorch_flatbuffer::KernelTypes::Tensor)) {
+        continue;
+      }
+      if (lhs[i].dtype != rhs[i].dtype) {
+        return lhs[i].dtype < rhs[i].dtype;
+      }
+      if (lhs[i].dim_order != rhs[i].dim_order) {
+        return lhs[i].dim_order < rhs[i].dim_order;
+      }
+    }
+    return false;
+  }
+};
+
+using KernelIOMetadata = std::vector<IOMetaData>;
+
+using OpIOMetaData = std::set<KernelIOMetadata, KernelIOMetaDataComparison>;
+
+// Returns "name" or "name.overload" for every operator listed in the plan.
+std::vector<std::string> get_operators_from_execution_plan(
+    const executorch_flatbuffer::ExecutionPlan& plan) {
+  std::vector<std::string> op_names;
+  for (const executorch_flatbuffer::Operator* op : *plan.operators()) {
+    if (op->overload()->str().empty()) {
+      op_names.push_back(op->name()->str());
+    } else {
+      op_names.push_back(op->name()->str() + "." + op->overload()->str());
+    }
+  }
+  return op_names;
+}
+
+// Maps each operator called by the plan to the set of distinct IO signatures
+// it is called with.
+std::map<std::string, OpIOMetaData>
+get_kernel_tensor_metadatas_from_execution_plan(
+    const executorch_flatbuffer::ExecutionPlan* plan) {
+  std::map<std::string, OpIOMetaData> op_io_metadata;
+  for (const executorch_flatbuffer::Chain* chain : *plan->chains()) {
+    for (const executorch_flatbuffer::Instruction* inst :
+         *chain->instructions()) {
+      if (inst->instr_args_type() ==
+          executorch_flatbuffer::InstructionArguments::KernelCall) {
+        const executorch_flatbuffer::KernelCall* kernel_call =
+            inst->instr_args_as_KernelCall();
+        const executorch_flatbuffer::Operator* op =
+            plan->operators()->Get(kernel_call->op_index());
+        std::string op_overload_name = op->name()->str();
+        if (op->overload()->size()) {
+          op_overload_name += "." + op->overload()->str();
+        }
+
+        // create an empty entry if current kernel is not in the map.
+        if (op_io_metadata.count(op_overload_name) == 0) {
+          op_io_metadata.insert(
+              std::make_pair(op_overload_name, OpIOMetaData()));
+        }
+
+        // go through IOs of this operator and collect tensor metadatas.
+        KernelIOMetadata kernel_io_metadata;
+        for (int arg_id : *kernel_call->args()) {
+          const executorch_flatbuffer::EValue* arg =
+              plan->values()->Get(arg_id);
+          if (arg->val_type() == executorch_flatbuffer::KernelTypes::Tensor) {
+            kernel_io_metadata.push_back(IOMetaData(arg->val_as_Tensor()));
+          } else if (
+              arg->val_type() ==
+              executorch_flatbuffer::KernelTypes::TensorList) {
+            if (arg->val_as_TensorList()->items()->size() == 0) {
+              // treat empty tensor list as null type since we can not get
+              // metadata from it.
+              kernel_io_metadata.push_back(
+                  IOMetaData(executorch_flatbuffer::KernelTypes::Null));
+            } else {
+              // all elements in TensorList are tensors and share the same
+              // metadata. use the metadata of first element as the metadata for
+              // whole list.
+              const executorch_flatbuffer::Tensor* tensor_arg =
+                  plan->values()
+                      ->Get(arg->val_as_TensorList()->items()->Get(0))
+                      ->val_as_Tensor();
+              kernel_io_metadata.push_back(IOMetaData(tensor_arg));
+            }
+          } else if (
+              arg->val_type() ==
+              executorch_flatbuffer::KernelTypes::OptionalTensorList) {
+            // all elements in OptionalTensorList are either tensor or null, and
+            // all tensors share same metadata. Use the metadata of first tensor
+            // element as the metadata for whole list. If no tensor exists (e.g.
+            // each element is None), treat the whole list as a single null
+            // element.
+            const executorch_flatbuffer::OptionalTensorList* opt_tensor_list =
+                arg->val_as_OptionalTensorList();
+
+            // Find one non-null tensor
+            bool found_tensor_element = false;
+            for (size_t i = 0; i < opt_tensor_list->items()->size(); i++) {
+              // We now adopt both index == -1 and actually serialize a null
+              // type EValue to represent a null data.
+              if (opt_tensor_list->items()->Get(i) != -1 &&
+                  plan->values()
+                          ->Get(opt_tensor_list->items()->Get(i))
+                          ->val_type() ==
+                      executorch_flatbuffer::KernelTypes::Tensor) {
+                const executorch_flatbuffer::Tensor* tensor_arg =
+                    plan->values()
+                        ->Get(opt_tensor_list->items()->Get(i))
+                        ->val_as_Tensor();
+                kernel_io_metadata.push_back(IOMetaData(tensor_arg));
+                found_tensor_element = true;
+                break;
+              }
+            }
+            if (!found_tensor_element) {
+              kernel_io_metadata.push_back(
+                  IOMetaData(executorch_flatbuffer::KernelTypes::Null));
+            }
+          } else {
+            kernel_io_metadata.push_back(IOMetaData(arg->val_type()));
+          }
+        }
+        op_io_metadata[op_overload_name].insert(kernel_io_metadata);
+      }
+    }
+  }
+  return op_io_metadata;
+}
+} // namespace
+
+// Returns a Program that points directly into `buffer`'s storage (no copy is
+// made), so `buffer` must outlive the returned pointer.
+const executorch_flatbuffer::Program* _get_program_from_buffer(
+    const py::bytes& buffer) {
+  return executorch_flatbuffer::GetProgram(
+      buffer.cast<std::string_view>().data());
+}
+
+// Returns the operator names of every execution plan in the program.
+py::list _get_program_operators(const executorch_flatbuffer::Program* program) {
+  const auto& plans = *program->execution_plan();
+  std::vector<std::string> op_names;
+  for (const auto& plan : plans) {
+    auto plan_ops = get_operators_from_execution_plan(*plan);
+    if (!plan_ops.empty()) {
+      op_names.insert(op_names.end(), plan_ops.begin(), plan_ops.end());
+    }
+  }
+  return py::cast(op_names);
+}
+
+// expose IO metadatas for all operators in given program
+py::dict _get_io_metadata_for_program_operators(
+    const executorch_flatbuffer::Program* program) {
+  const auto& plans = *program->execution_plan();
+  std::map<std::string, OpIOMetaData> program_op_io_metadata;
+
+  // aggregate op metadata from different execution plans.
+  for (const executorch_flatbuffer::ExecutionPlan* plan : plans) {
+    std::map<std::string, OpIOMetaData> plan_op_io_metadata =
+        get_kernel_tensor_metadatas_from_execution_plan(plan);
+
+    for (const auto& [op_name, kernel_io_metadatas] : plan_op_io_metadata) {
+      // operator[] creates the empty set the first time op_name is seen.
+      program_op_io_metadata[op_name].insert(
+          kernel_io_metadatas.begin(), kernel_io_metadatas.end());
+    }
+  }
+
+  // convert program_op_io_metadata to py data structure.
+  py::dict py_program_op_io_metadata;
+  for (const auto& op_io_meta : program_op_io_metadata) {
+    py::set py_op_io_meta;
+    for (const auto& io_metas : op_io_meta.second) {
+      py::list py_io_metadatas;
+      for (const auto& io_metadata : io_metas) {
+        py_io_metadatas.append(io_metadata);
+      }
+      py_op_io_meta.add(py::tuple(py_io_metadatas));
+    }
+    py_program_op_io_metadata[op_io_meta.first.data()] = py_op_io_meta;
+  }
+
+  return py_program_op_io_metadata;
+}
+
+PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
+  py::class_<executorch_flatbuffer::Program>(m, "_Program");
+
+  // The returned _Program aliases the bytes argument's storage, so tie the
+  // lifetime of the bytes (1) to the returned object (0).
+  m.def(
+      "_get_program_from_buffer",
+      &_get_program_from_buffer,
+      py::return_value_policy::reference,
+      py::keep_alive<0, 1>());
+
+  m.def(
+      "_get_program_operators",
+      &_get_program_operators,
+      py::return_value_policy::copy);
+
+  m.def(
+      "_get_io_metadata_for_program_operators",
+      &_get_io_metadata_for_program_operators,
+      py::return_value_policy::copy);
+
+  py::class_<IOMetaData>(m, "_IOMetaData")
+      .def_readwrite("kernel_type", &IOMetaData::kernel_type)
+      .def_readwrite("dtype", &IOMetaData::dtype)
+      .def_readwrite("dim_order", &IOMetaData::dim_order);
+}
+
+} // namespace executor
+} // namespace torch
diff --git a/codegen/tools/selective_build.pyi b/codegen/tools/selective_build.pyi
new file mode 100644
index 00000000000..c80213623f1
--- /dev/null
+++ b/codegen/tools/selective_build.pyi
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, List
+
+class _Program: ...
+
+class _IOMetaData:
+    @property
+    def kernel_type(self) -> int: ...
+    @property
+    def dtype(self) -> int: ...
+    @property
+    def dim_order(self) -> List[int]: ...
+
+def _get_program_from_buffer(buffer: bytes) -> _Program: ...
+def _get_program_operators(program: _Program) -> List[str]: ...
+def _get_io_metadata_for_program_operators(
+    program: _Program,
+) -> Dict[str, Any]: ...
diff --git a/codegen/tools/targets.bzl b/codegen/tools/targets.bzl
index bf298a76d44..39de8fcb482 100644
--- a/codegen/tools/targets.bzl
+++ b/codegen/tools/targets.bzl
@@ -19,7 +19,7 @@ def define_common_targets(is_fbcode = False):
             "//executorch/codegen:gen_lib",
         ] + select({
             "DEFAULT": [],
-            "ovr_config//os:linux": [] if runtime.is_oss else ["//executorch/codegen/tools/fb:selective_build"],  # TODO(larryliu0820) :selective_build doesn't build in OSS yet
+            "ovr_config//os:linux": [] if runtime.is_oss else ["//executorch/codegen/tools:selective_build"],  # TODO(larryliu0820) :selective_build doesn't build in OSS yet
         }),
     )
 
@@ -29,7 +29,7 @@ def define_common_targets(is_fbcode = False):
         deps = [
             ":gen_oplist_lib",
         ],
-        preload_deps = [] if runtime.is_oss else ["//executorch/codegen/tools/fb:selective_build"],  # TODO(larryliu0820) :selective_build doesn't build in OSS yet
+        preload_deps = [] if runtime.is_oss else ["//executorch/codegen/tools:selective_build"],  # TODO(larryliu0820) :selective_build doesn't build in OSS yet
         package_style = "inplace",
         visibility = [
             "//executorch/...",
@@ -155,6 +155,28 @@ def define_common_targets(is_fbcode = False):
         _is_external_target = True,
     )
 
+    if not runtime.is_oss:
+        runtime.cxx_python_extension(
+            name = "selective_build",
+            srcs = [
+                "selective_build.cpp",
+            ],
+            base_module = "executorch.codegen.tools",
+            types = ["selective_build.pyi"],
+            preprocessor_flags = [
+                "-DEXECUTORCH_PYTHON_MODULE_NAME=selective_build",
+            ],
+            deps = [
+                "//executorch/runtime/core:core",
+                "//executorch/schema:program",
+            ],
+            external_deps = [
+                "pybind11",
+            ],
+            use_static_deps = True,
+            visibility = ["//executorch/codegen/..."],
+        )
+
     # TODO(larryliu0820): This is a hack to only run these two on fbcode. These targets depends on exir which is only available in fbcode.
     if not runtime.is_oss and is_fbcode:
         runtime.python_binary(
@@ -190,3 +212,21 @@ def define_common_targets(is_fbcode = False):
                 "//libfb/py:parutil",
             ],
         )
+
+        runtime.python_test(
+            name = "test_selective_build",
+            srcs = [
+                "test/test_selective_build.py",
+            ],
+            package_style = "inplace",
+            visibility = [
+                "PUBLIC",
+            ],
+            deps = [
+                ":selective_build",
+                "fbsource//third-party/pypi/expecttest:expecttest",
+                "//caffe2:torch",
+                "//executorch/exir:lib",
+            ],
+            _is_external_target = True,
+        )
diff --git a/codegen/tools/test/test_selective_build.py b/codegen/tools/test/test_selective_build.py
new file mode 100644
index 00000000000..5d8cd80659f
--- /dev/null
+++ b/codegen/tools/test/test_selective_build.py
@@ -0,0 +1,129 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from typing import Any, Optional, Tuple, Union
+
+import torch
+
+from executorch.codegen.tools.selective_build import (  # type: ignore[import-not-found]
+    _get_io_metadata_for_program_operators,
+    _get_program_from_buffer,
+    _get_program_operators,
+    _IOMetaData,
+)
+from executorch.exir import ExecutorchProgramManager, to_edge
+from executorch.exir.scalar_type import ScalarType
+from torch.export import export
+
+
+class ModuleAdd(torch.nn.Module):
+    """The module to serialize and execute."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return x + y
+
+    def get_methods_to_export(self):
+        return ("forward",)
+
+
+class ModuleMulti(torch.nn.Module):
+    """The module to serialize and execute."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return x + y
+
+    def forward2(self, x, y):
+        return x + y + 1
+
+    def get_methods_to_export(self):
+        return ("forward", "forward2")
+
+
+def create_program(
+    eager_module: Optional[Union[ModuleAdd, ModuleMulti]] = None,
+) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
+    """Exports eager_module (ModuleAdd if None); returns program and inputs."""
+
+    if eager_module is None:
+        eager_module = ModuleAdd()
+
+    class WrapperModule(torch.nn.Module):
+        def __init__(self, fn):
+            super().__init__()
+            self.fn = fn
+
+        def forward(self, *args, **kwargs):
+            return self.fn(*args, **kwargs)
+
+    # Trace the test module and create a serialized ExecuTorch program.
+    inputs = (torch.ones(2, 2), torch.ones(2, 2))
+    input_map = {}
+    # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
+    #  is not a function.
+    for method in eager_module.get_methods_to_export():
+        input_map[method] = inputs
+
+    exported_methods = {}
+    # These cleanup passes are required to convert the `add` op to its out
+    # variant, along with some other transformations.
+    for method_name, method_input in input_map.items():
+        module = WrapperModule(getattr(eager_module, method_name))
+        exported_methods[method_name] = export(module, method_input, strict=True)
+
+    exec_prog = to_edge(exported_methods).to_executorch()
+
+    # Create the ExecuTorch program from the graph.
+    exec_prog.dump_executorch_program(verbose=True)
+    return (exec_prog, inputs)
+
+
+class PybindingsTest(unittest.TestCase):
+    def test_dump_operators(self):
+        # Create and serialize a program.
+        orig_program, _ = create_program()
+
+        # Deserialize the program and demonstrate that we could get its operator
+        # list.
+        program = _get_program_from_buffer(orig_program.buffer)
+        operators = _get_program_operators(program)
+        self.assertEqual(operators, ["aten::add.out"])
+
+    def test_get_op_io_meta(self):
+        # Checking whether get_op_io_meta returns the correct metadata for all its ios.
+        orig_program, _ = create_program()
+
+        # Deserialize the program and collect the IO metadata of each of its
+        # operators.
+        program = _get_program_from_buffer(orig_program.buffer)
+        program_op_io_metadata = _get_io_metadata_for_program_operators(program)
+
+        self.assertEqual(len(program_op_io_metadata), 1)
+        self.assertIsInstance(program_op_io_metadata, dict)
+
+        self.assertIn("aten::add.out", program_op_io_metadata)
+        self.assertIsInstance(program_op_io_metadata["aten::add.out"], set)
+        self.assertEqual(len(program_op_io_metadata["aten::add.out"]), 1)
+
+        for op_io_metadata in program_op_io_metadata["aten::add.out"]:
+            self.assertEqual(len(op_io_metadata), 5)
+            self.assertIsInstance(op_io_metadata, tuple)
+
+            for io_idx, io_metadata in enumerate(op_io_metadata):
+                self.assertIsInstance(io_metadata, _IOMetaData)
+                if io_idx == 2:
+                    # TODO(gasoonjia): Create an enum class to map KernelTypes to int, remove the hardcoded 2 and 5 below.
+                    self.assertEqual(io_metadata.kernel_type, 2)
+                else:
+                    self.assertEqual(io_metadata.kernel_type, 5)
+                    self.assertEqual(io_metadata.dtype, ScalarType.FLOAT)
+                    self.assertEqual(io_metadata.dim_order, [0, 1])
diff --git a/setup.py b/setup.py
index 86b946eebbd..cb0dcbbe9f7 100644
--- a/setup.py
+++ b/setup.py
@@ -729,6 +729,7 @@ def run(self):  # noqa C901
 
         if cmake_cache.is_enabled("EXECUTORCH_BUILD_PYBIND"):
             cmake_build_args += ["--target", "portable_lib"]
+            cmake_build_args += ["--target", "selective_build"]
 
         if cmake_cache.is_enabled("EXECUTORCH_BUILD_EXTENSION_TRAINING"):
             cmake_build_args += ["--target", "_training_lib"]
@@ -790,6 +791,11 @@ def run(self):  # noqa C901
             modpath="executorch.extension.training.pybindings._training_lib",
             dependent_cmake_flags=["EXECUTORCH_BUILD_EXTENSION_TRAINING"],
         ),
+        BuiltExtension(
+            src="codegen/tools/selective_build.*",
+            modpath="executorch.codegen.tools.selective_build",
+            dependent_cmake_flags=["EXECUTORCH_BUILD_PYBIND"],
+        ),
         BuiltExtension(
             src="executorchcoreml.*",
             src_dir="backends/apple/coreml",