From c504b9b890625bb3adab17120a5ce45bc7c29a2f Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Tue, 20 May 2025 23:03:02 -0700 Subject: [PATCH] Introduce PAL function table (#10675) Summary: This is the implementation of the PAL changes described in https://github.com/pytorch/executorch/discussions/10432. This PR introduces a struct (`pal_table`) to hold function pointers to the PAL function implementations. There is a singleton instance of this struct, which is initialized with the weak/strong et_pal_ functions - maintaining backwards compatibility with the existing override mechanism. I've then added wrapper functions for the PAL into the executorch::runtime namespace which dispatch through the function table. It is intended that callers use these functions instead of the "raw" et_pal_ methods in order to correctly dispatch through the function table. I also update ET callers to do this. Differential Revision: D74121895 --- .github/workflows/pull.yml | 4 +- .../xnnpack/runtime/profiling/XNNProfiler.cpp | 4 +- devtools/etdump/etdump_flatcc.cpp | 8 +- .../executor_runner/executor_runner.cpp | 6 +- runtime/executor/platform_memory_allocator.h | 4 +- runtime/platform/abort.cpp | 2 +- runtime/platform/log.cpp | 4 +- runtime/platform/platform.cpp | 160 ++++++++++++++++++ runtime/platform/platform.h | 139 +++++++++++++++ runtime/platform/targets.bzl | 1 + runtime/platform/test/CMakeLists.txt | 4 + .../test/executor_pal_override_test.cpp | 74 +------- .../executor_pal_runtime_override_test.cpp | 130 ++++++++++++++ ...cutor_pal_static_runtime_override_test.cpp | 94 ++++++++++ runtime/platform/test/pal_spy.h | 86 ++++++++++ runtime/platform/test/targets.bzl | 13 ++ 16 files changed, 645 insertions(+), 88 deletions(-) create mode 100644 runtime/platform/platform.cpp create mode 100644 runtime/platform/test/executor_pal_runtime_override_test.cpp create mode 100644 runtime/platform/test/executor_pal_static_runtime_override_test.cpp create mode 100644 runtime/platform/test/pal_spy.h diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 5b3dd671701..5b45e203f16 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -371,7 +371,7 @@ jobs: size=${arr[4]} # threshold=48120 on devserver with gcc11.4 # todo(lfq): update once binary size is below 50kb. - threshold="55504" + threshold="55584" if [[ "$size" -le "$threshold" ]]; then echo "Success $size <= $threshold" else @@ -406,7 +406,7 @@ jobs: output=$(ls -la cmake-out/test/size_test) arr=($output) size=${arr[4]} - threshold="51656" + threshold="51728" if [[ "$size" -le "$threshold" ]]; then echo "Success $size <= $threshold" else diff --git a/backends/xnnpack/runtime/profiling/XNNProfiler.cpp b/backends/xnnpack/runtime/profiling/XNNProfiler.cpp index 72614083c74..aed3394c714 100644 --- a/backends/xnnpack/runtime/profiling/XNNProfiler.cpp +++ b/backends/xnnpack/runtime/profiling/XNNProfiler.cpp @@ -62,7 +62,7 @@ Error XNNProfiler::start(EventTracer* event_tracer) { state_ = XNNProfilerState::Running; // Log the start of execution timestamp. - start_time_ = et_pal_current_ticks(); + start_time_ = runtime::pal_current_ticks(); return Error::Ok; } @@ -187,7 +187,7 @@ void XNNProfiler::log_operator_timings() { void XNNProfiler::submit_trace() { // Retrieve the system tick rate (ratio between ticks and nanoseconds). - auto tick_ns_conv_multiplier = et_pal_ticks_to_ns_multiplier(); + auto tick_ns_conv_multiplier = runtime::pal_ticks_to_ns_multiplier(); ET_CHECK(op_timings_.size() == op_count_); size_t name_len = 0; diff --git a/devtools/etdump/etdump_flatcc.cpp b/devtools/etdump/etdump_flatcc.cpp index ccca2beb257..a6e0a105069 100644 --- a/devtools/etdump/etdump_flatcc.cpp +++ b/devtools/etdump/etdump_flatcc.cpp @@ -224,7 +224,7 @@ EventTracerEntry ETDumpGen::start_profiling( prof_entry.chain_id = chain_id; prof_entry.debug_handle = debug_handle; } - prof_entry.start_time = et_pal_current_ticks(); + prof_entry.start_time = runtime::pal_current_ticks(); return prof_entry; } @@ -246,7 +246,7 @@ EventTracerEntry ETDumpGen::start_profiling_delegate( prof_entry.event_id = delegate_debug_index == kUnsetDelegateDebugIntId ? create_string_entry(name) : delegate_debug_index; - prof_entry.start_time = et_pal_current_ticks(); + prof_entry.start_time = runtime::pal_current_ticks(); return prof_entry; } @@ -254,7 +254,7 @@ void ETDumpGen::end_profiling_delegate( EventTracerEntry event_tracer_entry, const void* metadata, size_t metadata_len) { - et_timestamp_t end_time = et_pal_current_ticks(); + et_timestamp_t end_time = runtime::pal_current_ticks(); check_ready_to_add_events(); // Start building the ProfileEvent entry. @@ -469,7 +469,7 @@ Result ETDumpGen::log_intermediate_output_delegate_helper( } void ETDumpGen::end_profiling(EventTracerEntry prof_entry) { - et_timestamp_t end_time = et_pal_current_ticks(); + et_timestamp_t end_time = runtime::pal_current_ticks(); ET_CHECK_MSG( prof_entry.delegate_event_id_type == DelegateDebugIdType::kNone, "Delegate events must use end_profiling_delegate to mark the end of a delegate profiling event."); diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp index b5d008ff1e7..434b4783bac 100644 --- a/examples/portable/executor_runner/executor_runner.cpp +++ b/examples/portable/executor_runner/executor_runner.cpp @@ -269,9 +269,11 @@ int main(int argc, char** argv) { (uint32_t)inputs.error()); ET_LOG(Debug, "Inputs prepared."); - const et_timestamp_t before_execute = et_pal_current_ticks(); + const et_timestamp_t before_execute = + executorch::runtime::pal_current_ticks(); Error status = method->execute(); - const et_timestamp_t after_execute = et_pal_current_ticks(); + const et_timestamp_t after_execute = + executorch::runtime::pal_current_ticks(); time_spent_executing += after_execute - before_execute; ET_CHECK_MSG( status == Error::Ok, diff --git a/runtime/executor/platform_memory_allocator.h b/runtime/executor/platform_memory_allocator.h index 7ab58bf0f3d..5951f116d3d 100644 --- a/runtime/executor/platform_memory_allocator.h +++ b/runtime/executor/platform_memory_allocator.h @@ -48,7 +48,7 @@ class PlatformMemoryAllocator final : public MemoryAllocator { // Allocate enough memory for the node, the data and the alignment bump. size_t alloc_size = sizeof(AllocationNode) + size + alignment; - void* node_memory = et_pal_allocate(alloc_size); + void* node_memory = runtime::pal_allocate(alloc_size); // If allocation failed, log message and return nullptr. if (node_memory == nullptr) { @@ -87,7 +87,7 @@ class PlatformMemoryAllocator final : public MemoryAllocator { AllocationNode* current = head_; while (current != nullptr) { AllocationNode* next = current->next; - et_pal_free(current); + runtime::pal_free(current); current = next; } head_ = nullptr; diff --git a/runtime/platform/abort.cpp b/runtime/platform/abort.cpp index 27320e4845a..155726fb2e7 100644 --- a/runtime/platform/abort.cpp +++ b/runtime/platform/abort.cpp @@ -17,7 +17,7 @@ namespace runtime { * up, and set an abnormal exit status (platform-defined). */ ET_NORETURN void runtime_abort() { - et_pal_abort(); + pal_abort(); } } // namespace runtime diff --git a/runtime/platform/log.cpp b/runtime/platform/log.cpp index 6529c73b238..b338ee10a71 100644 --- a/runtime/platform/log.cpp +++ b/runtime/platform/log.cpp @@ -23,7 +23,7 @@ namespace internal { * @retval Monotonically non-decreasing timestamp in system ticks. */ et_timestamp_t get_log_timestamp() { - return et_pal_current_ticks(); + return pal_current_ticks(); } // Double-check that the log levels are ordered from lowest to highest severity. @@ -96,7 +96,7 @@ void vlogf( ? kLevelToPal[size_t(level)] : et_pal_log_level_t::kUnknown; - et_pal_emit_log_message( + pal_emit_log_message( timestamp, pal_level, filename, function, line, buf, len); #endif // ET_LOG_ENABLED diff --git a/runtime/platform/platform.cpp b/runtime/platform/platform.cpp new file mode 100644 index 00000000000..850051f1fe1 --- /dev/null +++ b/runtime/platform/platform.cpp @@ -0,0 +1,160 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace executorch::runtime { + +namespace { +/** + * The singleton instance of the PAL function table. + */ +PalImpl pal_impl = { + et_pal_init, + et_pal_abort, + et_pal_current_ticks, + et_pal_ticks_to_ns_multiplier, + et_pal_emit_log_message, + et_pal_allocate, + et_pal_free, + __FILE__}; + +/** + * Tracks whether the PAL has been overridden. This is used to warn when + * multiple callers override the PAL. + */ +bool is_pal_overridden = false; +} // namespace + +PalImpl PalImpl::create( + pal_emit_log_message_method emit_log_message, + const char* source_filename) { + return PalImpl::create( + nullptr, // init + nullptr, // abort + nullptr, // current_ticks + nullptr, // ticks_to_ns_multiplier + emit_log_message, + nullptr, // allocate + nullptr, // free + source_filename); +} + +PalImpl PalImpl::create( + pal_init_method init, + pal_abort_method abort, + pal_current_ticks_method current_ticks, + pal_ticks_to_ns_multiplier_method ticks_to_ns_multiplier, + pal_emit_log_message_method emit_log_message, + pal_allocate_method allocate, + pal_free_method free, + const char* source_filename) { + return PalImpl{ + init, + abort, + current_ticks, + ticks_to_ns_multiplier, + emit_log_message, + allocate, + free, + source_filename}; +} + +/** + * Override the PAL functions with user implementations. Any null entries in the + * table are unchanged and will keep the default implementation. + */ +bool register_pal(PalImpl impl) { + if (is_pal_overridden) { + ET_LOG( + Error, + "register_pal() called multiple times. Subsequent calls will override the previous implementation. Previous implementation was registered from %s.", + impl.source_filename != nullptr ? impl.source_filename : "unknown"); + } + is_pal_overridden = true; + + if (impl.abort != nullptr) { + pal_impl.abort = impl.abort; + } + + if (impl.current_ticks != nullptr) { + pal_impl.current_ticks = impl.current_ticks; + } + + if (impl.ticks_to_ns_multiplier != nullptr) { + pal_impl.ticks_to_ns_multiplier = impl.ticks_to_ns_multiplier; + } + + if (impl.emit_log_message != nullptr) { + pal_impl.emit_log_message = impl.emit_log_message; + } + + if (impl.allocate != nullptr) { + pal_impl.allocate = impl.allocate; + } + + if (impl.free != nullptr) { + pal_impl.free = impl.free; + } + + if (impl.init != nullptr) { + pal_impl.init = impl.init; + if (pal_impl.init != nullptr) { + pal_impl.init(); + } + } + + return true; +} + +const PalImpl* get_pal_impl() { + return &pal_impl; +} + +void pal_init() { + pal_impl.init(); +} + +ET_NORETURN void pal_abort() { + pal_impl.abort(); + // This should be unreachable, but in case the PAL implementation doesn't + // abort, force it here. + std::abort(); +} + +et_timestamp_t pal_current_ticks() { + return pal_impl.current_ticks(); +} + +et_tick_ratio_t pal_ticks_to_ns_multiplier() { + return pal_impl.ticks_to_ns_multiplier(); +} + +void pal_emit_log_message( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + const char* function, + size_t line, + const char* message, + size_t length) { + pal_impl.emit_log_message( + timestamp, level, filename, function, line, message, length); +} + +void* pal_allocate(size_t size) { + return pal_impl.allocate(size); +} + +void pal_free(void* ptr) { + pal_impl.free(ptr); +} + +} // namespace executorch::runtime diff --git a/runtime/platform/platform.h b/runtime/platform/platform.h index 03cdef8eb2f..72353054914 100644 --- a/runtime/platform/platform.h +++ b/runtime/platform/platform.h @@ -11,6 +11,10 @@ * Platform abstraction layer to allow individual platform libraries to override * symbols in ExecuTorch. PAL functions are defined as C functions so a platform * library implementer can use C in lieu of C++. + * + * The et_pal_ methods should not be called directly. Use the corresponding + * methods in the executorch::runtime namespace instead to appropriately + * dispatch through the PAL function table. */ #pragma once @@ -53,12 +57,14 @@ typedef struct { * to initialize any global state. Typically overridden by PAL implementer. */ void et_pal_init(void) ET_INTERNAL_PLATFORM_WEAKNESS; +using pal_init_method = void (*)(); /** * Immediately abort execution, setting the device into an error state, if * available. */ ET_NORETURN void et_pal_abort(void) ET_INTERNAL_PLATFORM_WEAKNESS; +using pal_abort_method = void (*)(); /** * Return a monotonically non-decreasing timestamp in system ticks. @@ -66,6 +72,8 @@ ET_NORETURN void et_pal_abort(void) ET_INTERNAL_PLATFORM_WEAKNESS; * @retval Timestamp value in system ticks. */ et_timestamp_t et_pal_current_ticks(void) ET_INTERNAL_PLATFORM_WEAKNESS; +typedef et_timestamp_t (*et_pal_current_ticks_t)(void); +using pal_current_ticks_method = et_timestamp_t (*)(); /** * Return the conversion rate from system ticks to nanoseconds as a fraction. @@ -81,6 +89,7 @@ et_timestamp_t et_pal_current_ticks(void) ET_INTERNAL_PLATFORM_WEAKNESS; */ et_tick_ratio_t et_pal_ticks_to_ns_multiplier(void) ET_INTERNAL_PLATFORM_WEAKNESS; +using pal_ticks_to_ns_multiplier_method = et_tick_ratio_t (*)(); /** * Severity level of a log message. Values must map to printable 7-bit ASCII @@ -114,6 +123,14 @@ void et_pal_emit_log_message( size_t line, const char* message, size_t length) ET_INTERNAL_PLATFORM_WEAKNESS; +using pal_emit_log_message_method = void (*)( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + const char* function, + size_t line, + const char* message, + size_t length); /** * NOTE: Core runtime code must not call this directly. It may only be called by @@ -126,6 +143,7 @@ void et_pal_emit_log_message( * et_pal_free(). */ void* et_pal_allocate(size_t size) ET_INTERNAL_PLATFORM_WEAKNESS; +using pal_allocate_method = void* (*)(size_t size); /** * Frees memory allocated by et_pal_allocate(). @@ -133,5 +151,126 @@ void* et_pal_allocate(size_t size) ET_INTERNAL_PLATFORM_WEAKNESS; * @param[in] ptr Pointer to memory to free. May be nullptr. */ void et_pal_free(void* ptr) ET_INTERNAL_PLATFORM_WEAKNESS; +using pal_free_method = void (*)(void* ptr); } // extern "C" + +namespace executorch::runtime { + +/** + * Table of pointers to platform abstraction layer functions. + */ +struct PalImpl { + // Note that this struct cannot contain constructors in order to ensure that + // the singleton instance can be initialized without relying on a global + // constructor. If it does require a global constructor, there can be a race + // between the init of the default PAL and the user static registration code. + static PalImpl create( + pal_emit_log_message_method emit_log_message, + const char* source_filename); + + static PalImpl create( + pal_init_method init, + pal_abort_method abort, + pal_current_ticks_method current_ticks, + pal_ticks_to_ns_multiplier_method ticks_to_ns_multiplier, + pal_emit_log_message_method emit_log_message, + pal_allocate_method allocate, + pal_free_method free, + const char* source_filename); + + pal_init_method init = nullptr; + pal_abort_method abort = nullptr; + pal_current_ticks_method current_ticks = nullptr; + pal_ticks_to_ns_multiplier_method ticks_to_ns_multiplier = nullptr; + pal_emit_log_message_method emit_log_message = nullptr; + pal_allocate_method allocate = nullptr; + pal_free_method free = nullptr; + + // An optional metadata field, indicating the name of the source + // file that registered the PAL implementation. + const char* source_filename; +}; + +/** + * Override the PAL functions with user implementations. Any null entries in the + * table are unchanged and will keep the default implementation. + * + * Returns true if the registration was successful, false otherwise. + */ +bool register_pal(PalImpl impl); + +/** + * Returns the PAL function table, which contains function pointers to the + * active implementation of each PAL function. + */ +const PalImpl* get_pal_impl(); + +/** + * Initialize the platform abstraction layer. + * + * This function should be called before any other function provided by the PAL + * to initialize any global state. Typically overridden by PAL implementer. + */ +void pal_init(); + +/** + * Immediately abort execution, setting the device into an error state, if + * available. + */ +ET_NORETURN void pal_abort(); + +/** + * Return a monotonically non-decreasing timestamp in system ticks. + * + * @retval Timestamp value in system ticks. + */ +et_timestamp_t pal_current_ticks(); + +/** + * Return the conversion rate from system ticks to nanoseconds as a fraction. + * To convert a system ticks to nanoseconds, multiply the tick count by the + * numerator and then divide by the denominator: + * nanoseconds = ticks * numerator / denominator + * + * The utility method executorch::runtime::ticks_to_ns(et_timestamp_t) can also + * be used to perform the conversion for a given tick count. It is defined in + * torch/executor/runtime/platform/clock.h. + * + * @retval The ratio of nanoseconds to system ticks. + */ +et_tick_ratio_t pal_ticks_to_ns_multiplier(); + +/** + * Severity level of a log message. Values must map to printable 7-bit ASCII + * uppercase letters. + */ +void pal_emit_log_message( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + const char* function, + size_t line, + const char* message, + size_t length); + +/** + * NOTE: Core runtime code must not call this directly. It may only be called by + * a MemoryAllocator wrapper. + * + * Allocates size bytes of memory. + * + * @param[in] size Number of bytes to allocate. + * @returns the allocated memory, or nullptr on failure. Must be freed using + * et_pal_free(). + */ +void* pal_allocate(size_t size); + +/** + * Frees memory allocated by et_pal_allocate(). + * + * @param[in] ptr Pointer to memory to free. May be nullptr. + */ +void pal_free(void* ptr); + +} // namespace executorch::runtime diff --git a/runtime/platform/targets.bzl b/runtime/platform/targets.bzl index 5235101648e..eecac8ae5db 100644 --- a/runtime/platform/targets.bzl +++ b/runtime/platform/targets.bzl @@ -76,6 +76,7 @@ def define_common_targets(): srcs = [ "abort.cpp", "log.cpp", + "platform.cpp", "profiler.cpp", "runtime.cpp", ], diff --git a/runtime/platform/test/CMakeLists.txt b/runtime/platform/test/CMakeLists.txt index 0afffaaabf0..356c05a01e7 100644 --- a/runtime/platform/test/CMakeLists.txt +++ b/runtime/platform/test/CMakeLists.txt @@ -19,6 +19,10 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) et_cxx_test(platform_test SOURCES executor_pal_test.cpp) +et_cxx_test(platform_runtime_override_test SOURCES executor_pal_runtime_override_test.cpp stub_platform.cpp) + +et_cxx_test(platform_static_runtime_override_test SOURCES executor_pal_static_runtime_override_test.cpp) + # TODO: Re-enable this test on OSS # et_cxx_test(platform_death_test SOURCES executor_pal_death_test.cpp) diff --git a/runtime/platform/test/executor_pal_override_test.cpp b/runtime/platform/test/executor_pal_override_test.cpp index 9bc500e652e..07cf7490983 100644 --- a/runtime/platform/test/executor_pal_override_test.cpp +++ b/runtime/platform/test/executor_pal_override_test.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -16,79 +17,6 @@ using namespace ::testing; using executorch::runtime::LogLevel; -class PalSpy : public PlatformIntercept { - public: - PalSpy() = default; - - void init() override { - ++init_call_count; - } - - static constexpr et_timestamp_t kTimestamp = 1234; - - et_timestamp_t current_ticks() override { - ++current_ticks_call_count; - return kTimestamp; - } - - et_tick_ratio_t ticks_to_ns_multiplier() override { - return tick_ns_multiplier; - } - - void emit_log_message( - et_timestamp_t timestamp, - et_pal_log_level_t level, - const char* filename, - const char* function, - size_t line, - const char* message, - size_t length) override { - ++emit_log_message_call_count; - last_log_message_args.timestamp = timestamp; - last_log_message_args.level = level; - last_log_message_args.filename = filename; - last_log_message_args.function = function; - last_log_message_args.line = line; - last_log_message_args.message = message; - last_log_message_args.length = length; - } - - void* allocate(size_t size) override { - ++allocate_call_count; - last_allocated_size = size; - last_allocated_ptr = (void*)0x1234; - return nullptr; - } - - void free(void* ptr) override { - ++free_call_count; - last_freed_ptr = ptr; - } - - virtual ~PalSpy() = default; - - size_t init_call_count = 0; - size_t current_ticks_call_count = 0; - size_t emit_log_message_call_count = 0; - et_tick_ratio_t tick_ns_multiplier = {1, 1}; - size_t allocate_call_count = 0; - size_t free_call_count = 0; - size_t last_allocated_size = 0; - void* last_allocated_ptr = nullptr; - void* last_freed_ptr = nullptr; - - /// The args that were passed to the most recent call to emit_log_message(). - struct { - et_timestamp_t timestamp; - et_pal_log_level_t level; - std::string filename; // Copy of the char* to avoid lifetime issues. - std::string function; - size_t line; - std::string message; - size_t length; - } last_log_message_args = {}; -}; - // Demonstrate what would happen if we didn't intercept the PAL calls. TEST(ExecutorPalOverrideTest, DiesIfNotIntercepted) { ET_EXPECT_DEATH( diff --git a/runtime/platform/test/executor_pal_runtime_override_test.cpp b/runtime/platform/test/executor_pal_runtime_override_test.cpp new file mode 100644 index 00000000000..ee0526a8b6b --- /dev/null +++ b/runtime/platform/test/executor_pal_runtime_override_test.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace { +PalSpy* active_spy; + +void pal_init(void) { + active_spy->init(); +} + +et_timestamp_t pal_current_ticks(void) { + return active_spy->current_ticks(); +} + +et_tick_ratio_t pal_ticks_to_ns_multiplier(void) { + return active_spy->ticks_to_ns_multiplier(); +} + +void pal_emit_log_message( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + ET_UNUSED const char* function, + size_t line, + const char* message, + ET_UNUSED size_t length) { + active_spy->emit_log_message( + timestamp, level, filename, function, line, message, length); +} + +void* pal_allocate(size_t size) { + return active_spy->allocate(size); +} + +void pal_free(void* ptr) { + active_spy->free(ptr); +} +} // namespace + +class RuntimePalOverrideTest : public ::testing::Test { + protected: + void SetUp() override { + // Capture the current PAL implementation so that it can be restored + // after the test. + _original_pal_impl = *executorch::runtime::get_pal_impl(); + } + + void TearDown() override { + // Restore the original PAL implementation. + + // This is a slightly hacky way to allow this test to exist alongside + // the executor_pal_override_test, which provides a build-time override + // for et_pal_init. This implementation asserts that an intercept exists. + // Since register_pal calls init, we need to make sure that an intercept + // is registered. It will be deregistered when it goes out of scope, + // allowing the tests to run in any order. + InterceptWith iw(_spy); + auto success = executorch::runtime::register_pal(_original_pal_impl); + if (!success) { + throw std::runtime_error("Failed to restore PAL implementation."); + } + } + + void RegisterSpy() { + active_spy = &_spy; + + executorch::runtime::register_pal(executorch::runtime::PalImpl::create( + pal_init, + nullptr, // abort + pal_current_ticks, + pal_ticks_to_ns_multiplier, + pal_emit_log_message, + pal_allocate, + pal_free, + __FILE__)); + } + + PalSpy _spy; + + private: + // The PAL implementation at the time of setup. + executorch::runtime::PalImpl _original_pal_impl; +}; + +TEST_F(RuntimePalOverrideTest, SmokeTest) { + EXPECT_EQ(_spy.init_call_count, 0); + EXPECT_EQ(_spy.current_ticks_call_count, 0); + EXPECT_EQ(_spy.allocate_call_count, 0); + EXPECT_EQ(_spy.free_call_count, 0); + + RegisterSpy(); + + // Expect register to call init. + EXPECT_EQ(_spy.init_call_count, 1); + + EXPECT_EQ(executorch::runtime::pal_current_ticks(), 1234); + EXPECT_EQ(_spy.current_ticks_call_count, 1); + + et_tick_ratio_t ticks_to_ns_multiplier = + executorch::runtime::pal_ticks_to_ns_multiplier(); + EXPECT_EQ(ticks_to_ns_multiplier.numerator, 1); + EXPECT_EQ(ticks_to_ns_multiplier.denominator, 1); + + executorch::runtime::pal_emit_log_message( + 5, kError, "test.cpp", "test_function", 6, "test message", 7); + EXPECT_EQ(_spy.emit_log_message_call_count, 1); + EXPECT_EQ(_spy.last_log_message_args.timestamp, 5); + EXPECT_EQ(_spy.last_log_message_args.level, kError); + EXPECT_EQ(_spy.last_log_message_args.filename, "test.cpp"); + EXPECT_EQ(_spy.last_log_message_args.function, "test_function"); + EXPECT_EQ(_spy.last_log_message_args.line, 6); + EXPECT_EQ(_spy.last_log_message_args.message, "test message"); + EXPECT_EQ(_spy.last_log_message_args.length, 7); + + executorch::runtime::pal_allocate(16); + EXPECT_EQ(_spy.allocate_call_count, 1); + + executorch::runtime::pal_free(nullptr); + EXPECT_EQ(_spy.free_call_count, 1); +} diff --git a/runtime/platform/test/executor_pal_static_runtime_override_test.cpp b/runtime/platform/test/executor_pal_static_runtime_override_test.cpp new file mode 100644 index 00000000000..baf5da52b5c --- /dev/null +++ b/runtime/platform/test/executor_pal_static_runtime_override_test.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace { +PalSpy spy = PalSpy(); + +void pal_init(void) { + spy.init(); +} + +et_timestamp_t pal_current_ticks(void) { + return spy.current_ticks(); +} + +et_tick_ratio_t pal_ticks_to_ns_multiplier(void) { + return spy.ticks_to_ns_multiplier(); +} + +void pal_emit_log_message( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + ET_UNUSED const char* function, + size_t line, + const char* message, + ET_UNUSED size_t length) { + spy.emit_log_message( + timestamp, level, filename, function, line, message, length); +} + +void* pal_allocate(size_t size) { + return spy.allocate(size); +} + +void pal_free(void* ptr) { + spy.free(ptr); +} + +// Statically register PAL impleementation. +bool registration_result = + executorch::runtime::register_pal(executorch::runtime::PalImpl::create( + pal_init, + nullptr, // abort + pal_current_ticks, + pal_ticks_to_ns_multiplier, + pal_emit_log_message, + pal_allocate, + pal_free, + __FILE__)); +} // namespace + +TEST(RuntimePalOverrideTest, SmokeTest) { + EXPECT_EQ(spy.current_ticks_call_count, 0); + EXPECT_EQ(spy.allocate_call_count, 0); + EXPECT_EQ(spy.free_call_count, 0); + + // Expect registration to call init. + EXPECT_EQ(spy.init_call_count, 1); + + EXPECT_EQ(executorch::runtime::pal_current_ticks(), 1234); + EXPECT_EQ(spy.current_ticks_call_count, 1); + + et_tick_ratio_t ticks_to_ns_multiplier = + executorch::runtime::pal_ticks_to_ns_multiplier(); + EXPECT_EQ(ticks_to_ns_multiplier.numerator, 1); + EXPECT_EQ(ticks_to_ns_multiplier.denominator, 1); + + executorch::runtime::pal_emit_log_message( + 5, kError, "test.cpp", "test_function", 6, "test message", 7); + EXPECT_EQ(spy.emit_log_message_call_count, 1); + EXPECT_EQ(spy.last_log_message_args.timestamp, 5); + EXPECT_EQ(spy.last_log_message_args.level, kError); + EXPECT_EQ(spy.last_log_message_args.filename, "test.cpp"); + EXPECT_EQ(spy.last_log_message_args.function, "test_function"); + EXPECT_EQ(spy.last_log_message_args.line, 6); + EXPECT_EQ(spy.last_log_message_args.message, "test message"); + EXPECT_EQ(spy.last_log_message_args.length, 7); + + executorch::runtime::pal_allocate(16); + EXPECT_EQ(spy.allocate_call_count, 1); + + executorch::runtime::pal_free(nullptr); + EXPECT_EQ(spy.free_call_count, 1); +} diff --git a/runtime/platform/test/pal_spy.h b/runtime/platform/test/pal_spy.h new file mode 100644 index 00000000000..5831ba17a33 --- /dev/null +++ b/runtime/platform/test/pal_spy.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +class PalSpy : public PlatformIntercept { + public: + PalSpy() = default; + + void init() override { + ++init_call_count; + } + + static constexpr et_timestamp_t kTimestamp = 1234; + + et_timestamp_t current_ticks() override { + ++current_ticks_call_count; + return kTimestamp; + } + + et_tick_ratio_t ticks_to_ns_multiplier() override { + return tick_ns_multiplier; + } + + void emit_log_message( + et_timestamp_t timestamp, + et_pal_log_level_t level, + const char* filename, + const char* function, + size_t line, + const char* message, + size_t length) override { + ++emit_log_message_call_count; + last_log_message_args.timestamp = timestamp; + last_log_message_args.level = level; + last_log_message_args.filename = filename; + last_log_message_args.function = function; + last_log_message_args.line = line; + last_log_message_args.message = message; + last_log_message_args.length = length; + } + + void* allocate(size_t size) override { + ++allocate_call_count; + last_allocated_size = size; + last_allocated_ptr = (void*)0x1234; + return nullptr; + } + + void free(void* ptr) override { + ++free_call_count; + last_freed_ptr = ptr; + } + + virtual ~PalSpy() = default; + + size_t init_call_count = 0; + size_t current_ticks_call_count = 0; + size_t emit_log_message_call_count = 0; + et_tick_ratio_t tick_ns_multiplier = {1, 1}; + size_t allocate_call_count = 0; + size_t free_call_count = 0; + size_t last_allocated_size = 0; + void* last_allocated_ptr = nullptr; + void* last_freed_ptr = nullptr; + + /// The args that were passed to the most recent call to emit_log_message(). + struct { + et_timestamp_t timestamp; + et_pal_log_level_t level; + std::string filename; // Copy of the char* to avoid lifetime issues. + std::string function; + size_t line; + std::string message; + size_t length; + } last_log_message_args = {}; +}; diff --git a/runtime/platform/test/targets.bzl b/runtime/platform/test/targets.bzl index 9955b5fb6b8..6a46eb29f4b 100644 --- a/runtime/platform/test/targets.bzl +++ b/runtime/platform/test/targets.bzl @@ -40,6 +40,7 @@ def define_common_targets(): "stub_platform.cpp", ], exported_headers = [ + "pal_spy.h", "stub_platform.h", ], deps = [ @@ -65,6 +66,18 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "runtime_platform_override_test", + srcs = [ + "executor_pal_runtime_override_test.cpp", + ], + deps = [ + ":stub_platform", + "//executorch/runtime/core:core", + "//executorch/runtime/platform:platform", + ], + ) + runtime.cxx_test( name = "logging_test", srcs = [