Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 7 additions & 156 deletions extension/pytree/function_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,167 +6,18 @@
* LICENSE file in the root directory of this source tree.
*/

//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains some extension to <functional>.
//
// No library is required when using these functions.
//
//===----------------------------------------------------------------------===//
// Extra additions to <functional>
//===----------------------------------------------------------------------===//

/// An efficient, type-erasing, non-owning reference to a callable. This is
/// intended for use as the type of a function parameter that is not used
/// after the function in question returns.
///
/// This class does not own the callable, so it is not in general safe to store
/// a FunctionRef.

// torch::executor: modified from llvm::function_ref
// see https://www.foonathan.net/2017/01/function-ref-implementation/

#pragma once

#include <cstdint>
#include <type_traits>
#include <utility>

namespace executorch {
namespace extension {
namespace pytree {

//===----------------------------------------------------------------------===//
// Features from C++20
//===----------------------------------------------------------------------===//

namespace internal {

template <typename T>
struct remove_cvref {
using type =
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
};

template <typename T>
using remove_cvref_t = typename remove_cvref<T>::type;

} // namespace internal

template <typename Fn>
class FunctionRef;

template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
Ret (*callback_)(const void* memory, Params... params) = nullptr;
union Storage {
void* callable;
Ret (*function)(Params...);
} storage_;

public:
FunctionRef() = default;
explicit FunctionRef(std::nullptr_t) {}

/**
* Case 1: A callable object passed by lvalue reference.
* Taking rvalue reference is error prone because the object will be always
* be destroyed immediately.
*/
template <
typename Callable,
// This is not the copy-constructor.
typename std::enable_if<
!std::is_same<internal::remove_cvref_t<Callable>, FunctionRef>::value,
int32_t>::type = 0,
// Avoid lvalue reference to non-capturing lambda.
typename std::enable_if<
!std::is_convertible<Callable, Ret (*)(Params...)>::value,
int32_t>::type = 0,
// Functor must be callable and return a suitable type.
// To make this container type safe, we need to ensure either:
// 1. The return type is void.
// 2. Or the resulting type from calling the callable is convertible to
// the declared return type.
typename std::enable_if<
std::is_void<Ret>::value ||
std::is_convertible<
decltype(std::declval<Callable>()(std::declval<Params>()...)),
Ret>::value,
int32_t>::type = 0>
explicit FunctionRef(Callable& callable)
: callback_([](const void* memory, Params... params) {
auto& storage = *static_cast<const Storage*>(memory);
auto& callable = *static_cast<Callable*>(storage.callable);
return static_cast<Ret>(callable(std::forward<Params>(params)...));
}) {
storage_.callable = &callable;
}

/**
* Case 2: A plain function pointer.
* Instead of storing an opaque pointer to underlying callable object,
* store a function pointer directly.
* Note that in the future a variant which coerces compatible function
* pointers could be implemented by erasing the storage type.
*/
/* implicit */ FunctionRef(Ret (*ptr)(Params...))
: callback_([](const void* memory, Params... params) {
auto& storage = *static_cast<const Storage*>(memory);
return storage.function(std::forward<Params>(params)...);
}) {
storage_.function = ptr;
}

/**
* Case 3: Implicit conversion from lambda to FunctionRef.
* A common use pattern is like:
* void foo(FunctionRef<...>) {...}
* foo([](...){...})
* Here constructors for non const lvalue reference or function pointer
* would not work because they do not cover implicit conversion from rvalue
* lambda.
* We need to define a constructor for capturing temporary callables and
* always try to convert the lambda to a function pointer behind the scene.
*/
template <
typename Function,
// This is not the copy-constructor.
typename std::enable_if<
!std::is_same<Function, FunctionRef>::value,
int32_t>::type = 0,
// Function is convertible to pointer of (Params...) -> Ret.
typename std::enable_if<
std::is_convertible<Function, Ret (*)(Params...)>::value,
int32_t>::type = 0>
/* implicit */ FunctionRef(const Function& function)
: FunctionRef(static_cast<Ret (*)(Params...)>(function)) {}

Ret operator()(Params... params) const {
return callback_(&storage_, std::forward<Params>(params)...);
}
#include <executorch/runtime/core/function_ref.h>

explicit operator bool() const {
return callback_;
}
};
/// This header is DEPRECATED; use executorch/runtime/core/function_ref.h directly instead.

} // namespace pytree
} // namespace extension
} // namespace executorch
namespace executorch::extension::pytree {
using executorch::runtime::FunctionRef;
} // namespace executorch::extension::pytree

namespace torch {
namespace executor {
namespace pytree {
namespace torch::executor::pytree {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::pytree::FunctionRef;
} // namespace pytree
} // namespace executor
} // namespace torch
} // namespace torch::executor::pytree
2 changes: 1 addition & 1 deletion extension/pytree/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)

include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)

set(_test_srcs function_ref_test.cpp test_pytree.cpp)
set(_test_srcs test_pytree.cpp)

et_cxx_test(extension_pytree_test SOURCES ${_test_srcs} EXTRA_LIBS)
6 changes: 0 additions & 6 deletions extension/pytree/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@ cpp_unittest(
deps = ["//executorch/extension/pytree:pytree"],
)

cpp_unittest(
name = "function_ref_test",
srcs = ["function_ref_test.cpp"],
deps = ["//executorch/extension/pytree:pytree"],
)

python_unittest(
name = "pybindings_test",
srcs = [
Expand Down
105 changes: 105 additions & 0 deletions runtime/core/function_ref.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains some extension to <functional>.
//
// No library is required when using these functions.
//
//===----------------------------------------------------------------------===//
// Extra additions to <functional>
//===----------------------------------------------------------------------===//

/// An efficient, type-erasing, non-owning reference to a callable. This is
/// intended for use as the type of a function parameter that is not used
/// after the function in question returns.
///
/// This class does not own the callable, so it is not in general safe to store
/// a FunctionRef.

// torch::executor: modified from llvm::function_ref
// - renamed to FunctionRef
// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses
// - use namespaced internal::remove_cvref_t


#pragma once

#include <cstdint>
#include <type_traits>
#include <utility>

namespace executorch::runtime {

//===----------------------------------------------------------------------===//
// Features from C++20
//===----------------------------------------------------------------------===//

namespace internal {

template <typename T>
struct remove_cvref {
using type =
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
};

template <typename T>
using remove_cvref_t = typename remove_cvref<T>::type;

} // namespace internal

template <typename Fn>
class FunctionRef;

template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
Ret (*callback)(intptr_t callable, Params ...params) = nullptr;
intptr_t callable;

template<typename Callable>
static Ret callback_fn(intptr_t callable, Params ...params) {
return (*reinterpret_cast<Callable*>(callable))(
std::forward<Params>(params)...);
}

public:
FunctionRef() = default;
FunctionRef(std::nullptr_t) {}

template <typename Callable>
FunctionRef(
Callable &&callable,
// This is not the copy-constructor.
std::enable_if_t<!std::is_same<internal::remove_cvref_t<Callable>,
FunctionRef>::value> * = nullptr,
// Functor must be callable and return a suitable type.
std::enable_if_t<std::is_void<Ret>::value ||
std::is_convertible<decltype(std::declval<Callable>()(
std::declval<Params>()...)),
Ret>::value> * = nullptr)
: callback(callback_fn<std::remove_reference_t<Callable>>),
callable(reinterpret_cast<intptr_t>(&callable)) {}

Ret operator()(Params ...params) const {
return callback(callable, std::forward<Params>(params)...);
}

explicit operator bool() const { return callback; }

bool operator==(const FunctionRef<Ret(Params...)> &Other) const {
return callable == Other.callable;
}
};
} // namespace executorch::runtime
1 change: 1 addition & 0 deletions runtime/core/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def define_common_targets():
"defines.h",
"error.h",
"freeable_buffer.h",
"function_ref.h",
"result.h",
"span.h",
],
Expand Down
9 changes: 5 additions & 4 deletions runtime/core/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)

set(_test_srcs
span_test.cpp
array_ref_test.cpp
error_handling_test.cpp
evalue_test.cpp
event_tracer_test.cpp
freeable_buffer_test.cpp
array_ref_test.cpp
memory_allocator_test.cpp
function_ref_test.cpp
hierarchical_allocator_test.cpp
evalue_test.cpp
memory_allocator_test.cpp
span_test.cpp
)

et_cxx_test(runtime_core_test SOURCES ${_test_srcs} EXTRA_LIBS)
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,6 @@ using namespace ::testing;
using ::executorch::extension::pytree::FunctionRef;

namespace {
class Item {
private:
int32_t val_;
FunctionRef<void(int32_t&)> ref_;

public:
/* implicit */ Item(int32_t val, FunctionRef<void(int32_t&)> ref)
: val_(val), ref_(ref) {}

int32_t get() {
ref_(val_);
return val_;
}
};

void one(int32_t& i) {
i = 1;
}
Expand All @@ -39,8 +24,9 @@ void one(int32_t& i) {
TEST(FunctionRefTest, CapturingLambda) {
auto one = 1;
auto f = [&](int32_t& i) { i = one; };
Item item(0, FunctionRef<void(int32_t&)>{f});
EXPECT_EQ(item.get(), 1);
int32_t val = 0;
FunctionRef<void(int32_t&)>{f}(val);
EXPECT_EQ(val, 1);
// ERROR:
// Item item1(0, f);
// Item item2(0, [&](int32_t& i) { i = 2; });
Expand All @@ -58,16 +44,6 @@ TEST(FunctionRefTest, NonCapturingLambda) {
FunctionRef<void(int32_t&)> ref1(lambda);
ref1(val);
EXPECT_EQ(val, 1);

Item item(0, [](int32_t& i) { i = 1; });
EXPECT_EQ(item.get(), 1);

auto f = [](int32_t& i) { i = 1; };
Item item1(0, f);
EXPECT_EQ(item1.get(), 1);

Item item2(0, std::move(f));
EXPECT_EQ(item2.get(), 1);
}

TEST(FunctionRefTest, FunctionPointer) {
Expand All @@ -76,9 +52,8 @@ TEST(FunctionRefTest, FunctionPointer) {
ref(val);
EXPECT_EQ(val, 1);

Item item(0, one);
EXPECT_EQ(item.get(), 1);

Item item1(0, &one);
EXPECT_EQ(item1.get(), 1);
val = 0;
FunctionRef<void(int32_t&)> ref2(one);
ref(val);
EXPECT_EQ(val, 1);
}
Loading
Loading