diff --git a/runtime/core/freeable_buffer.h b/runtime/core/freeable_buffer.h index a90c899103d..c743f32116a 100644 --- a/runtime/core/freeable_buffer.h +++ b/runtime/core/freeable_buffer.h @@ -9,6 +9,12 @@ #pragma once #include +#include +#include + +#include +#include +#include namespace executorch { namespace runtime { @@ -20,20 +26,35 @@ class FreeableBuffer final { public: // Callback signature for the function that does the freeing. using FreeFn = void (*)(void* context, void* data, size_t size); + using FreeUInt64Fn = + void (*)(void* context, uint64_t data_uint64, size_t size); + + private: + // Forward declare types. + struct PointerData { + const void* data_; + FreeFn free_fn_; + }; + struct UInt64Data { + // A pointer value cast to uint64_t. + uint64_t data_; + FreeUInt64Fn free_fn_; + }; + + public: /** * Creates an empty FreeableBuffer with size zero and a null data pointer. */ FreeableBuffer() - : free_fn_(nullptr), + : data_(PointerData{nullptr, nullptr}), free_fn_context_(nullptr), - data_(nullptr), size_(0) {} /** * Creates a FreeableBuffer with an optional free function. * - * @param[in] data The data of the segment. + * @param[in] data The data of the segment, as a void*. * @param[in] size The size of the segment data, in bytes. * @param[in] free_fn Optional function to free the data. Guaranteed to be * called exactly once before the FreeableBuffer is destroyed. May be @@ -47,9 +68,35 @@ class FreeableBuffer final { size_t size, FreeFn free_fn, void* free_fn_context = nullptr) - : free_fn_(free_fn), + : data_(PointerData{data, free_fn}), + free_fn_context_(free_fn_context), + size_(size) {} + + /** + * Creates a FreeableBuffer with an optional free function. + * + * NOTE: most users should use the other ctor with FreeFn. + * This variant exists for situations where the FreeableBuffer points to + * memory on a different core whose pointer value is larger than the local + * core's void*. + * + * @param[in] data Pointer to the data of the segment, cast to a uint64_t + * value. + * @param[in] size The size of the segment data, in bytes. + * @param[in] free_fn Optional function to free the data. Guaranteed to be + * called exactly once before the FreeableBuffer is destroyed. May be + * nullptr. NOTE: This function must be thread-safe. If it modifies common + * state, the function must do its own locking. + * @param[in] free_fn_context Opaque pointer to pass as the `context` + * parameter of `free_fn`. May be nullptr. + */ + explicit FreeableBuffer( + const uint64_t data_uint64, + size_t size, + FreeUInt64Fn free_fn, + void* free_fn_context = nullptr) + : data_(UInt64Data{data_uint64, free_fn}), free_fn_context_(free_fn_context), - data_(data), size_(size) {} /** @@ -57,13 +104,15 @@ class FreeableBuffer final { * leaving `rhs` pointing to nullptr. */ FreeableBuffer(FreeableBuffer&& rhs) noexcept - : free_fn_(rhs.free_fn_), + : data_(rhs.data_), free_fn_context_(rhs.free_fn_context_), - data_(rhs.data_), size_(rhs.size_) { - rhs.free_fn_ = nullptr; + if (std::holds_alternative(rhs.data_)) { + rhs.data_ = PointerData{nullptr, nullptr}; + } else { + rhs.data_ = UInt64Data{0, nullptr}; + } rhs.free_fn_context_ = nullptr; - rhs.data_ = nullptr; rhs.size_ = 0; } @@ -75,11 +124,22 @@ class FreeableBuffer final { * Frees the data if not already free. Safe to call multiple times. */ void Free() { - if (data_ != nullptr) { - if (free_fn_ != nullptr) { - free_fn_(free_fn_context_, const_cast(data_), size_); + if (std::holds_alternative(data_)) { + PointerData& ptr_data = std::get(data_); + if (ptr_data.data_ != nullptr && ptr_data.free_fn_ != nullptr) { + // Do not need to check for truncation here, as free_fn_ is only set + // using the void* ctor. + ptr_data.free_fn_( + free_fn_context_, const_cast(ptr_data.data_), size_); } - data_ = nullptr; + ptr_data.data_ = nullptr; + size_ = 0; + } else { + UInt64Data& int64_data = std::get(data_); + if (int64_data.data_ != 0 && int64_data.free_fn_ != nullptr) { + int64_data.free_fn_(free_fn_context_, int64_data.data_, size_); + } + int64_data.data_ = static_cast(0); size_ = 0; } } @@ -95,7 +155,37 @@ class FreeableBuffer final { * Pointer to the data. Returns nullptr if the data has been freed. */ const void* data() const { - return data_; + ET_CHECK_MSG( + std::holds_alternative(data_), + "FreeableBuffer is backed by an uint64_t, please use the data_uint64_type() API."); + return std::get(data_).data_; + } + + /** + * Pointer to the data. Returns nullptr if the data has been freed. + * Safe version of data() API that returns an ERror if the data is + * backed by int64_t instead of void*. + */ + Result data_safe() const { + ET_CHECK_OR_RETURN_ERROR( + std::holds_alternative(data_), + InvalidType, + "FreeableBuffer is backed by an uint64_t, please use the data_uint64_type() API."); + return std::get(data_).data_; + } + + /** + * Data address as a uint64_t. Returns zero if the data has been freed. + * Most users should use data(). data_uint64_type() is only helpful in + * situations where the FreeableBuffer points to memory on a different core + * whose pointer value is larger than the local core's void *. + */ + Result data_uint64_type() const { + ET_CHECK_OR_RETURN_ERROR( + std::holds_alternative(data_), + InvalidType, + "FreeableBuffer is backed by a void*, please use the data() API."); + return std::get(data_).data_; } private: @@ -104,9 +194,15 @@ class FreeableBuffer final { FreeableBuffer& operator=(FreeableBuffer&& rhs) noexcept = delete; FreeableBuffer& operator=(const FreeableBuffer& rhs) = delete; - FreeFn free_fn_; + // This stores either a PointerData or a UInt64Data structure. Most users + // should use the PointerData variant and the void* ctor. This creates a + // FreeableBuffer backed by void*, accessed using the void* getter data(). + // The UInt64Data variant is only helpful in situations where the + // FreeableBuffer points to memory on a different core whose pointer value + // is larger than the local core's void*. + std::variant data_; + void* free_fn_context_; - const void* data_; size_t size_; }; diff --git a/runtime/core/test/freeable_buffer_test.cpp b/runtime/core/test/freeable_buffer_test.cpp index e2edff24227..2848a6b049d 100644 --- a/runtime/core/test/freeable_buffer_test.cpp +++ b/runtime/core/test/freeable_buffer_test.cpp @@ -6,16 +6,21 @@ * LICENSE file in the root directory of this source tree. */ +#include #include +#include +#include #include using namespace ::testing; + +using executorch::runtime::Error; using executorch::runtime::FreeableBuffer; struct FreeCallArgs { size_t calls; - void* data; + std::variant data; size_t size; }; @@ -26,9 +31,18 @@ void RecordFree(void* context, void* data, size_t size) { call->size = size; } +void RecordInt64Free(void* context, uint64_t data, size_t size) { + auto* call = reinterpret_cast(context); + call->calls++; + call->data = data; + call->size = size; +} + TEST(FreeableBufferTest, EmptyTest) { FreeableBuffer fb; EXPECT_EQ(fb.data(), nullptr); + EXPECT_EQ(fb.data_safe().error(), Error::Ok); + EXPECT_EQ(fb.data_safe().get(), nullptr); EXPECT_EQ(fb.size(), 0); } @@ -42,11 +56,33 @@ TEST(FreeableBufferTest, DataAndSizeTest) { // It should return the ctor params unmodified. EXPECT_EQ(fb.size(), sizeof(i)); EXPECT_EQ(fb.data(), &i); + EXPECT_EQ(fb.data_safe().error(), Error::Ok); + EXPECT_EQ(fb.data_safe().get(), &i); // Freeing should clear them, even though free_fn is nullptr. fb.Free(); EXPECT_EQ(fb.size(), 0); EXPECT_EQ(fb.data(), nullptr); + EXPECT_EQ(fb.data_safe().error(), Error::Ok); + EXPECT_EQ(fb.data_safe().get(), nullptr); + + // Use uint64_t constructor. + const uint64_t i64 = 1; + FreeableBuffer fb2( + /*data_uint64=*/i64, + /*size=*/sizeof(i64), + /*free_fn=*/nullptr); + + // It should return the ctor params unmodified. + EXPECT_EQ(fb2.size(), sizeof(i64)); + EXPECT_EQ(fb2.data_uint64_type().error(), Error::Ok); + EXPECT_EQ(fb2.data_uint64_type().get(), i64); + + // Freeing should clear them, even though free_fn is nullptr. + fb2.Free(); + EXPECT_EQ(fb2.size(), 0); + EXPECT_EQ(fb2.data_uint64_type().error(), Error::Ok); + EXPECT_EQ(fb2.data_uint64_type().get(), 0); } TEST(FreeableBufferTest, FreeTest) { @@ -68,7 +104,7 @@ TEST(FreeableBufferTest, FreeTest) { // Called once during Free() with the expected data/size. fb.Free(); EXPECT_EQ(call.calls, 1); - EXPECT_EQ(call.data, &i); + EXPECT_EQ(std::get(call.data), &i); EXPECT_EQ(call.size, sizeof(i)); // A second call to Free() should not call the function again. @@ -78,6 +114,31 @@ TEST(FreeableBufferTest, FreeTest) { // The destructor should not have called the function again. EXPECT_EQ(call.calls, 1); + + // Test with uint64_t constructor and free function. + FreeCallArgs call2 = {}; + { + uint64_t i64 = 1; + FreeableBuffer fb( + /*data_uint64=*/i64, + /*size=*/sizeof(i64), + /*free_fn=*/RecordInt64Free, + /*free_fn_context=*/&call2); + + // Not called during construction. + EXPECT_EQ(call2.calls, 0); + + // Called once during Free() with the expected data/size. + fb.Free(); + EXPECT_EQ(call2.calls, 1); + EXPECT_EQ(std::get(call2.data), i64); + EXPECT_EQ(call2.size, sizeof(i64)); + + // A second call to Free() should not call the function again. + fb.Free(); + EXPECT_EQ(call2.calls, 1); + } + EXPECT_EQ(call2.calls, 1); } TEST(FreeableBufferTest, DestructorTest) { @@ -99,8 +160,24 @@ TEST(FreeableBufferTest, DestructorTest) { // The destructor should have freed the data. EXPECT_EQ(call.calls, 1); - EXPECT_EQ(call.data, &i); + EXPECT_EQ(std::get(call.data), &i); EXPECT_EQ(call.size, sizeof(i)); + + // Test with uint64_t constructor and free function. + FreeCallArgs call2 = {}; + uint64_t i64 = 1; + { + FreeableBuffer fb2( + /*data_uint64=*/i64, + /*size=*/sizeof(i), + /*free_fn=*/RecordInt64Free, + /*free_fn_context=*/&call2); + EXPECT_EQ(call2.calls, 0); + } + // The destructor should have freed the data. + EXPECT_EQ(call2.calls, 1); + EXPECT_EQ(std::get(call2.data), i64); + EXPECT_EQ(call2.size, sizeof(i)); } TEST(FreeableBufferTest, MoveTest) { @@ -127,7 +204,6 @@ TEST(FreeableBufferTest, MoveTest) { // The destination FreeableBuffer should have the data. EXPECT_EQ(fb_dst.size(), sizeof(i)); EXPECT_EQ(fb_dst.data(), &i); - // Freeing the source FreeableBuffer should not call the free function. fb_src.Free(); EXPECT_EQ(call.calls, 0); @@ -135,6 +211,59 @@ TEST(FreeableBufferTest, MoveTest) { // Freeing the destination FreeableBuffer should call the free function. fb_dst.Free(); EXPECT_EQ(call.calls, 1); - EXPECT_EQ(call.data, &i); EXPECT_EQ(call.size, sizeof(i)); + + // Test with uint64_t constructor and free function. + FreeCallArgs call2 = {}; + const uint64_t i64 = 1; + FreeableBuffer fb_src2( + /*data_uint64=*/i64, + /*size=*/sizeof(i64), + /*free_fn=*/RecordInt64Free, + /*free_fn_context=*/&call2); + EXPECT_EQ(fb_src2.size(), sizeof(i64)); + EXPECT_EQ(fb_src2.data_uint64_type().error(), Error::Ok); + EXPECT_EQ(fb_src2.data_uint64_type().get(), i64); + + // Move it into a second FreeableBuffer. + FreeableBuffer fb_dst2(std::move(fb_src2)); + + // The source FreeableBuffer should now be empty. + EXPECT_EQ(fb_src2.size(), 0); // NOLINT(bugprone-use-after-move) + EXPECT_EQ( + fb_src2.data_uint64_type().error(), + Error::Ok); // NOLINT(bugprone-use-after-move) + EXPECT_EQ( + fb_src2.data_uint64_type().get(), 0); // NOLINT(bugprone-use-after-move) + + // The destination FreeableBuffer should have the data. + EXPECT_EQ(fb_dst2.size(), sizeof(i64)); + EXPECT_EQ(fb_dst2.data_uint64_type().error(), Error::Ok); + EXPECT_EQ(fb_dst2.data_uint64_type().get(), i64); + // Freeing the source FreeableBuffer should not call the free function. + fb_src2.Free(); + EXPECT_EQ(call2.calls, 0); + + // Freeing the destination FreeableBuffer should call the free function. + fb_dst2.Free(); + EXPECT_EQ(call2.calls, 1); + EXPECT_EQ(call2.size, sizeof(i64)); +} + +TEST(FreeableBufferTest, APIMisuseDeathTest) { + executorch::runtime::pal_init(); + int i; + FreeableBuffer fb( + /*data=*/&i, + /*size=*/sizeof(i), + /*free_fn=*/nullptr); + EXPECT_EQ(fb.data_uint64_type().error(), Error::InvalidType); + + uint64_t i64 = 1; + FreeableBuffer fb2( + /*data_uint64=*/i64, + /*size=*/sizeof(i64), + /*free_fn=*/nullptr); + EXPECT_EQ(fb2.data_safe().error(), Error::InvalidType); + ET_EXPECT_DEATH(fb2.data(), ".*"); }