Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 32 additions & 85 deletions extension/pytree/function_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
/// a FunctionRef.

// torch::executor: modified from llvm::function_ref
// see https://www.foonathan.net/2017/01/function-ref-implementation/
// - renamed to FunctionRef
// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses
// - use namespaced internal::remove_cvref_t


#pragma once

Expand Down Expand Up @@ -64,99 +67,43 @@ class FunctionRef;

template <typename Ret, typename... Params>
class FunctionRef<Ret(Params...)> {
Ret (*callback_)(const void* memory, Params... params) = nullptr;
union Storage {
void* callable;
Ret (*function)(Params...);
} storage_;
Ret (*callback)(intptr_t callable, Params ...params) = nullptr;
intptr_t callable;

public:
FunctionRef() = default;
explicit FunctionRef(std::nullptr_t) {}

/**
* Case 1: A callable object passed by lvalue reference.
* Taking rvalue reference is error prone because the object will be always
* be destroyed immediately.
*/
template <
typename Callable,
// This is not the copy-constructor.
typename std::enable_if<
!std::is_same<internal::remove_cvref_t<Callable>, FunctionRef>::value,
int32_t>::type = 0,
// Avoid lvalue reference to non-capturing lambda.
typename std::enable_if<
!std::is_convertible<Callable, Ret (*)(Params...)>::value,
int32_t>::type = 0,
// Functor must be callable and return a suitable type.
// To make this container type safe, we need to ensure either:
// 1. The return type is void.
// 2. Or the resulting type from calling the callable is convertible to
// the declared return type.
typename std::enable_if<
std::is_void<Ret>::value ||
std::is_convertible<
decltype(std::declval<Callable>()(std::declval<Params>()...)),
Ret>::value,
int32_t>::type = 0>
explicit FunctionRef(Callable& callable)
: callback_([](const void* memory, Params... params) {
auto& storage = *static_cast<const Storage*>(memory);
auto& callable = *static_cast<Callable*>(storage.callable);
return static_cast<Ret>(callable(std::forward<Params>(params)...));
}) {
storage_.callable = &callable;
template<typename Callable>
static Ret callback_fn(intptr_t callable, Params ...params) {
return (*reinterpret_cast<Callable*>(callable))(
std::forward<Params>(params)...);
}

/**
* Case 2: A plain function pointer.
* Instead of storing an opaque pointer to underlying callable object,
* store a function pointer directly.
* Note that in the future a variant which coerces compatible function
* pointers could be implemented by erasing the storage type.
*/
/* implicit */ FunctionRef(Ret (*ptr)(Params...))
: callback_([](const void* memory, Params... params) {
auto& storage = *static_cast<const Storage*>(memory);
return storage.function(std::forward<Params>(params)...);
}) {
storage_.function = ptr;
}
public:
FunctionRef() = default;
FunctionRef(std::nullptr_t) {}

/**
* Case 3: Implicit conversion from lambda to FunctionRef.
* A common use pattern is like:
* void foo(FunctionRef<...>) {...}
* foo([](...){...})
* Here constructors for non const lvalue reference or function pointer
* would not work because they do not cover implicit conversion from rvalue
* lambda.
* We need to define a constructor for capturing temporary callables and
* always try to convert the lambda to a function pointer behind the scene.
*/
template <
typename Function,
template <typename Callable>
FunctionRef(
Callable &&callable,
// This is not the copy-constructor.
typename std::enable_if<
!std::is_same<Function, FunctionRef>::value,
int32_t>::type = 0,
// Function is convertible to pointer of (Params...) -> Ret.
typename std::enable_if<
std::is_convertible<Function, Ret (*)(Params...)>::value,
int32_t>::type = 0>
/* implicit */ FunctionRef(const Function& function)
: FunctionRef(static_cast<Ret (*)(Params...)>(function)) {}

Ret operator()(Params... params) const {
return callback_(&storage_, std::forward<Params>(params)...);
std::enable_if_t<!std::is_same<internal::remove_cvref_t<Callable>,
FunctionRef>::value> * = nullptr,
// Functor must be callable and return a suitable type.
std::enable_if_t<std::is_void<Ret>::value ||
std::is_convertible<decltype(std::declval<Callable>()(
std::declval<Params>()...)),
Ret>::value> * = nullptr)
: callback(callback_fn<std::remove_reference_t<Callable>>),
callable(reinterpret_cast<intptr_t>(&callable)) {}

Ret operator()(Params ...params) const {
return callback(callable, std::forward<Params>(params)...);
}

explicit operator bool() const {
return callback_;
explicit operator bool() const { return callback; }

bool operator==(const FunctionRef<Ret(Params...)> &Other) const {
return callable == Other.callable;
}
};

} // namespace pytree
} // namespace extension
} // namespace executorch
Expand Down
39 changes: 7 additions & 32 deletions extension/pytree/test/function_ref_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,6 @@ using namespace ::testing;
using ::executorch::extension::pytree::FunctionRef;

namespace {
class Item {
private:
int32_t val_;
FunctionRef<void(int32_t&)> ref_;

public:
/* implicit */ Item(int32_t val, FunctionRef<void(int32_t&)> ref)
: val_(val), ref_(ref) {}

int32_t get() {
ref_(val_);
return val_;
}
};

void one(int32_t& i) {
i = 1;
}
Expand All @@ -39,8 +24,9 @@ void one(int32_t& i) {
TEST(FunctionRefTest, CapturingLambda) {
auto one = 1;
auto f = [&](int32_t& i) { i = one; };
Item item(0, FunctionRef<void(int32_t&)>{f});
EXPECT_EQ(item.get(), 1);
int32_t val = 0;
FunctionRef<void(int32_t&)>{f}(val);
EXPECT_EQ(val, 1);
// ERROR:
// Item item1(0, f);
// Item item2(0, [&](int32_t& i) { i = 2; });
Expand All @@ -58,16 +44,6 @@ TEST(FunctionRefTest, NonCapturingLambda) {
FunctionRef<void(int32_t&)> ref1(lambda);
ref1(val);
EXPECT_EQ(val, 1);

Item item(0, [](int32_t& i) { i = 1; });
EXPECT_EQ(item.get(), 1);

auto f = [](int32_t& i) { i = 1; };
Item item1(0, f);
EXPECT_EQ(item1.get(), 1);

Item item2(0, std::move(f));
EXPECT_EQ(item2.get(), 1);
}

TEST(FunctionRefTest, FunctionPointer) {
Expand All @@ -76,9 +52,8 @@ TEST(FunctionRefTest, FunctionPointer) {
ref(val);
EXPECT_EQ(val, 1);

Item item(0, one);
EXPECT_EQ(item.get(), 1);

Item item1(0, &one);
EXPECT_EQ(item1.get(), 1);
val = 0;
FunctionRef<void(int32_t&)> ref2(one);
ref(val);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/ref/ref2 perhaps?

EXPECT_EQ(val, 1);
}
Loading