diff --git a/backends/test/README.md b/backends/test/README.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backends/test/TARGETS b/backends/test/TARGETS new file mode 100644 index 00000000000..a6c52d105f6 --- /dev/null +++ b/backends/test/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets(is_fbcode = True) diff --git a/backends/test/multi_method_delegate_test.cpp b/backends/test/multi_method_delegate_test.cpp new file mode 100644 index 00000000000..e24585434c4 --- /dev/null +++ b/backends/test/multi_method_delegate_test.cpp @@ -0,0 +1,164 @@ +#include + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::HierarchicalAllocator; +using executorch::runtime::MemoryManager; +using executorch::runtime::Method; +using executorch::runtime::MethodMeta; +using executorch::runtime::Program; +using executorch::runtime::Result; +using executorch::runtime::Span; + +using executorch::extension::FileDataLoader; +using executorch::extension::MallocMemoryAllocator; +using executorch::extension::prepare_input_tensors; + +/* + * Backend agnostic base class. + */ +class ETPTEMethodRunBaseTest : public ::testing::Test { + protected: + void SetUp() override { + executorch::runtime::runtime_init(); + } + + // Runs the PTE e2e without using outside resources. + // This will run in a single thread. + // TODO(T208989128) - Add Synchronizer based run method. + void run( + const int id, + const std::string& kTestPTEPath, + const std::string& kMethodName, + std::atomic& count) const { + Result loader = FileDataLoader::from(kTestPTEPath.c_str()); + ASSERT_EQ(loader.error(), Error::Ok); + + Result program = Program::load( + &loader.get(), Program::Verification::InternalConsistency); + ASSERT_EQ(program.error(), Error::Ok); + + Result method_meta = program->method_meta(kMethodName.c_str()); + ASSERT_EQ(method_meta.error(), Error::Ok); + + const size_t num_memory_planned_buffers = + method_meta->num_memory_planned_buffers(); + + std::vector> planned_buffers; + std::vector> planned_spans; + for (size_t i = 0; i < num_memory_planned_buffers; ++i) { + const size_t buffer_size = + static_cast(method_meta->memory_planned_buffer_size(i).get()); + planned_buffers.push_back(std::make_unique(buffer_size)); + planned_spans.push_back({planned_buffers.back().get(), buffer_size}); + } + + auto method_allocator = std::make_unique(); + auto memory_planned_allocator = std::make_unique( + Span(planned_spans.data(), planned_spans.size())); + auto temp_allocator = std::make_unique(); + + auto memory_manager = std::make_unique( + method_allocator.get(), + memory_planned_allocator.get(), + temp_allocator.get()); + + Result method = + program->load_method(kMethodName.c_str(), memory_manager.get()); + ASSERT_EQ(method.error(), Error::Ok); + + auto inputs = prepare_input_tensors(*method); + ASSERT_EQ(inputs.error(), Error::Ok); + + Error err = method->execute(); + for (int i = 0; i < id % 7; i++) { + err = method->execute(); + ASSERT_EQ(err, Error::Ok); + } + + std::vector outputs(method->outputs_size()); + err = method->get_outputs(outputs.data(), outputs.size()); + ET_CHECK(err == Error::Ok); + // TODO(T208989129) - Add validation of outputs using bundled + // inputs/outputs. + count++; + } +}; + +class XNNPACKMultiDelegateTest : public ETPTEMethodRunBaseTest { + protected: + std::string kTestPTE1Path, kTestPTE2Path; + std::string kMethodName; + int num_threads; + + void SetUp() override { + ETPTEMethodRunBaseTest::SetUp(); + const char* pte1_path = + std::getenv("ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH"); + if (pte1_path == nullptr) { + std::cerr << "ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH is not set" + << std::endl; + FAIL(); + } + kTestPTE1Path = std::string(pte1_path); + + const char* pte2_path = + std::getenv("ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH"); + if (pte1_path == nullptr) { + std::cerr << "ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH is not set" + << std::endl; + FAIL(); + } + kTestPTE2Path = std::string(pte2_path); + + num_threads = 40; + kMethodName = "forward"; + } +}; + +// This test is to validate the assumption that the delegate is thread safe. +// That includes the following: +// 1. The delegate can be initilized by multiple threads in parallel. +// 2. The delegate can be executed by multiple threads in parallel. +// 3. The delegate can be destroyed by multiple threads in parallel. +// Regardless of the underlying implementation of the delegate. +// This is particularly important when we have shared resources across +// delegate instances through a singleton backend instance. +TEST_F(XNNPACKMultiDelegateTest, MultipleThreads) { + ASSERT_NE(kTestPTE1Path.size(), 0); + ASSERT_NE(kTestPTE2Path.size(), 0); + ASSERT_NE(num_threads, 0); + ASSERT_NE(kMethodName.size(), 0); + + std::vector threads(num_threads); + std::atomic count{0}; + + for (int i = 0; i < num_threads; i++) { + threads[i] = std::thread([&, i]() { + run(i, i % 7 ? kTestPTE1Path : kTestPTE2Path, kMethodName, count); + }); + } + for (int i = 0; i < num_threads; i++) { + threads[i].join(); + } + ASSERT_EQ(count, num_threads); +} + +// TODO(T208989291): Add more tests here. For example, +// - PTEs with multiple methods +// - PTEs with proucer and consumer relationships in different threads +// - PTEs with more than 1 delegate instances +// - PTEs with different type of delegate instances +// - Add more patterns of delegate initialization and execution diff --git a/backends/test/targets.bzl b/backends/test/targets.bzl new file mode 100644 index 00000000000..6588c57fcc7 --- /dev/null +++ b/backends/test/targets.bzl @@ -0,0 +1,29 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(is_fbcode = False): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + if not runtime.is_oss and is_fbcode: + modules_env = { + "ET_XNNPACK_GENERATED_ADD_LARGE_PTE_PATH": "$(location fbcode//executorch/test/models:exported_xnnp_delegated_programs[ModuleAddLarge.pte])", + "ET_XNNPACK_GENERATED_SUB_LARGE_PTE_PATH": "$(location fbcode//executorch/test/models:exported_xnnp_delegated_programs[ModuleSubLarge.pte])", + } + + runtime.cxx_test( + name = "multi_method_delegate_test", + srcs = [ + "multi_method_delegate_test.cpp", + ], + deps = [ + "//executorch/runtime/executor:program", + "//executorch/extension/data_loader:file_data_loader", + "//executorch/extension/memory_allocator:malloc_memory_allocator", + "//executorch/kernels/portable:generated_lib", + "//executorch/backends/xnnpack:xnnpack_backend", + "//executorch/extension/runner_util:inputs", + ], + env = modules_env, + ) diff --git a/test/models/export_delegated_program.py b/test/models/export_delegated_program.py index e9dccdbdf1d..a37fe32e556 100644 --- a/test/models/export_delegated_program.py +++ b/test/models/export_delegated_program.py @@ -13,7 +13,7 @@ import executorch.exir as exir import torch -from executorch.exir import to_edge +from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.test.backend_with_compiler_demo import ( @@ -52,6 +52,41 @@ def get_random_inputs(self) -> Sequence[torch.Tensor]: return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2)) +class ModuleAddLarge(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor + ) -> torch.Tensor: + x: torch.Tensor = torch.add(a, b) + y: torch.Tensor = torch.add(x, c) + z: torch.Tensor = torch.add(x, y) + return z + + def get_random_inputs(self) -> Sequence[torch.Tensor]: + n = 10 # to create a large tensor + return (torch.ones(n, n, n), 2 * torch.ones(n, n, n), 3 * torch.ones(n, n, n)) + + +class ModuleSubLarge(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor + ) -> torch.Tensor: + x: torch.Tensor = torch.sub(a, b) + y: torch.Tensor = torch.sub(x, c) + z: torch.Tensor = torch.sub(x, y) + w: torch.Tensor = torch.sub(z, c) + return w + + def get_random_inputs(self) -> Sequence[torch.Tensor]: + n = 10 # to create a large tensor + return (torch.ones(n, n, n), 2 * torch.ones(n, n, n), 3 * torch.ones(n, n, n)) + + # # Backends # @@ -95,30 +130,45 @@ def __init__(self, fn): def forward(self, *args, **kwargs): return self.fn(*args, **kwargs) - edge: exir.EdgeProgramManager = to_edge( - export(WrapperModule(getattr(eager_module, method)), args=inputs) + exported_program = export(WrapperModule(getattr(eager_module, method)), args=inputs) + + edge_config = EdgeCompileConfig(_check_ir_validity=False) + et_config = exir.ExecutorchBackendConfig( + extract_delegate_segments=extract_delegate_segments, + constant_tensor_alignment=constant_tensor_alignemnt, + delegate_alignment=delegate_alignment, ) - lowered_module = to_backend(backend_id, edge.exported_program(), compile_specs=[]) + if backend_id == "XnnpackBackend": + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, + ) - class CompositeModule(nn.Module): - def __init__(self): - super().__init__() - self.lowered_module = lowered_module + executorch_program = to_edge_transform_and_lower( + exported_program, + compile_config=edge_config, + partitioner=[XnnpackPartitioner()], + ).to_executorch(config=et_config) + else: + edge: exir.EdgeProgramManager = to_edge(exported_program) + lowered_module = to_backend( + backend_id, edge.exported_program(), compile_specs=[] + ) - def forward(self, *args, **kwargs): - return self.lowered_module(*args, **kwargs) + class CompositeModule(nn.Module): + def __init__(self): + super().__init__() + self.lowered_module = lowered_module - composite_module = CompositeModule() - composite_module(*inputs) + def forward(self, *args, **kwargs): + return self.lowered_module(*args, **kwargs) - executorch_program = to_edge(export(composite_module, args=inputs)).to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments, - constant_tensor_alignment=constant_tensor_alignemnt, - delegate_alignment=delegate_alignment, - ) - ) + composite_module = CompositeModule() + composite_module(*inputs) + + executorch_program = to_edge( + export(composite_module, args=inputs) + ).to_executorch(config=et_config) return executorch_program.buffer diff --git a/test/models/targets.bzl b/test/models/targets.bzl index aea47c9e036..f291a17c62b 100644 --- a/test/models/targets.bzl +++ b/test/models/targets.bzl @@ -117,6 +117,8 @@ def define_common_targets(): par_style = "xar", deps = [ ":export_delegated_program_lib", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + ], visibility = [], # Private ) @@ -124,6 +126,8 @@ def define_common_targets(): # Class names of nn.Modules for :exported_delegated_programs to export. DELEGATED_MODULES_TO_EXPORT = [ "ModuleAddMul", + "ModuleAddLarge", + "ModuleSubLarge", ] # Name of the backend to use when exporting delegated programs. @@ -153,3 +157,23 @@ def define_common_targets(): "//executorch/test/...", ], ) + + runtime.genrule( + name = "exported_xnnp_delegated_programs", + cmd = "$(exe :export_delegated_program)" + + " --modules " + ",".join(DELEGATED_MODULES_TO_EXPORT) + + " --backend_id " + "XnnpackBackend" + + " --outdir $OUT", + outs = { + fname + ".pte": [fname + ".pte"] + for fname in DELEGATED_MODULES_TO_EXPORT + }, + default_outs = ["."], + visibility = [ + "//executorch/runtime/executor/test/...", + "//executorch/backends/test/...", + "//executorch/test/...", + "@EXECUTORCH_CLIENTS", + ], + env = {"PYTORCH_DISABLE_JUSTKNOBS": "1",}, + )