diff --git a/kernels/portable/cpu/util/reduce_util.h b/kernels/portable/cpu/util/reduce_util.h index a41c580c98e..0c40ed90175 100644 --- a/kernels/portable/cpu/util/reduce_util.h +++ b/kernels/portable/cpu/util/reduce_util.h @@ -565,6 +565,15 @@ size_t compute_reduced_out_size( bool keepdim, exec_aten::SizesType* sizes_arr); +inline size_t compute_reduced_out_size( + const exec_aten::Tensor& in, + int64_t dim, + bool keepdim, + exec_aten::SizesType* sizes_arr) { + return compute_reduced_out_size( + in, exec_aten::optional(dim), keepdim, sizes_arr); +} + inline ssize_t compute_reduced_out_dim( const exec_aten::Tensor& in, const exec_aten::optional& dim, @@ -588,6 +597,14 @@ inline ssize_t compute_reduced_out_dim( : 0); } +inline ssize_t compute_reduced_out_dim( + const exec_aten::Tensor& in, + int64_t dim, + bool keepdim) { + return compute_reduced_out_dim( + in, exec_aten::optional(dim), keepdim); +} + // // Resize out tensor of reduction op // @@ -604,6 +621,15 @@ Error resize_reduction_out( bool keepdim, exec_aten::Tensor& out); +inline Error resize_reduction_out( + const exec_aten::Tensor& in, + int64_t dim, + bool keepdim, + exec_aten::Tensor& out) { + return resize_reduction_out( + in, exec_aten::optional(dim), keepdim, out); +} + #ifndef USE_ATEN_LIB bool check_reduction_args( const Tensor& in, diff --git a/runtime/core/exec_aten/exec_aten.h b/runtime/core/exec_aten/exec_aten.h index bfb47daa05d..66f36b38d46 100644 --- a/runtime/core/exec_aten/exec_aten.h +++ b/runtime/core/exec_aten/exec_aten.h @@ -104,8 +104,7 @@ using ArrayRef = torch::executor::ArrayRef; template using optional = torch::executor::optional; using nullopt_t = torch::executor::nullopt_t; -// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) -static constexpr nullopt_t nullopt{0}; +using torch::executor::nullopt; using ScalarType = torch::executor::ScalarType; using TensorList = ArrayRef; using Scalar = torch::executor::Scalar; diff --git a/runtime/core/portable_type/optional.h b/runtime/core/portable_type/optional.h index 21fe0d39267..b2aa7f2c341 100644 --- a/runtime/core/portable_type/optional.h +++ b/runtime/core/portable_type/optional.h @@ -8,175 +8,15 @@ #pragma once -#include -#include -#include // std::forward and other template magic checks +#include namespace executorch { namespace runtime { namespace etensor { -/// Used to indicate an optional type with uninitialized state. -struct nullopt_t final { - constexpr explicit nullopt_t(int32_t) {} -}; - -/// A constant of type nullopt_t that is used to indicate an optional type with -/// uninitialized state. -constexpr nullopt_t nullopt{0}; - -/// Leaner optional class, subset of c10, std, and boost optional APIs. -template -class optional final { - public: - /// The type wrapped by the optional class. - using value_type = T; - - /// Constructs an optional object that does not contain a value. - /* implicit */ optional() noexcept : storage_(trivial_init), init_(false) {} - - /// Constructs an optional object that does not contain a value. - /* implicit */ optional(nullopt_t) noexcept - : storage_(trivial_init), init_(false) {} - - /// Constructs an optional object that matches the state of v. - /* implicit */ optional(const optional& v) - : storage_(trivial_init), init_(v.init_) { - if (init_) { - new (&storage_.value_) T(v.storage_.value_); - } - } - - /// Constructs an optional object that contains the specified value. - /* implicit */ optional(const T& v) : storage_(v), init_(true) {} - - /// Constructs an optional object from v. - /* implicit */ optional(optional&& v) noexcept( - std::is_nothrow_move_constructible::value) - : storage_(trivial_init), init_(v.init_) { - if (init_) { - new (&storage_.value_) T(std::forward(v.storage_.value_)); - } - } - - /// Constructs an optional object that contains the specified value. - /* implicit */ optional(T&& v) : storage_(std::forward(v)), init_(true) {} - - optional& operator=(const optional& rhs) { - if (init_ && !rhs.init_) { - clear(); - } else if (!init_ && rhs.init_) { - init_ = true; - new (&storage_.value_) T(rhs.storage_.value_); - } else if (init_ && rhs.init_) { - storage_.value_ = rhs.storage_.value_; - } - return *this; - } - - optional& operator=(optional&& rhs) noexcept( - std::is_nothrow_move_assignable::value && - std::is_nothrow_move_constructible::value) { - if (init_ && !rhs.init_) { - clear(); - } else if (!init_ && rhs.init_) { - init_ = true; - new (&storage_.value_) T(std::forward(rhs.storage_.value_)); - } else if (init_ && rhs.init_) { - storage_.value_ = std::forward(rhs.storage_.value_); - } - return *this; - } - - /// Destroys the stored value if there is one - ~optional() { - if (init_) { - storage_.value_.~T(); - } - } - - optional& operator=(nullopt_t) noexcept { - clear(); - return *this; - } - - /// Returns true if the object contains a value, false otherwise - explicit operator bool() const noexcept { - return init_; - } - - /// Returns true if the object contains a value, false otherwise - bool has_value() const noexcept { - return init_; - } - - /// Returns a constant reference to the contained value. Calls ET_CHECK if - /// the object does not contain a value. - T const& value() const& { - ET_CHECK(init_); - return contained_val(); - } - - /// Returns a mutable reference to the contained value. Calls ET_CHECK if the - /// object does not contain a value. - T& value() & { - ET_CHECK(init_); - return contained_val(); - } - - /// Returns an rvalue of the contained value. Calls ET_CHECK if the object - /// does not contain a value. - T&& value() && { - ET_CHECK(init_); - return std::forward(contained_val()); - } - - private: - // Used to invoke the dummy ctor of storage_t in the initializer lists of - // optional_base as default ctor is implicitly deleted because T is nontrivial - struct trivial_init_t { - } trivial_init{}; - - /** - * A wrapper type that lets us avoid constructing a T when there is no value. - * If there is a value present, the optional class must destroy it. - */ - union storage_t { - /// A small, trivially-constructable alternative to T. - unsigned char dummy_; - /// The constructed value itself, if optional::has_value_ is true. - T value_; - - /* implicit */ storage_t(trivial_init_t) { - dummy_ = 0; - } - - template - storage_t(Args&&... args) : value_(std::forward(args)...) {} - - ~storage_t() {} - }; - - const T& contained_val() const& { - return storage_.value_; - } - T&& contained_val() && { - return std::move(storage_.value_); - } - T& contained_val() & { - return storage_.value_; - } - - void clear() noexcept { - if (init_) { - storage_.value_.~T(); - } - init_ = false; - } - - storage_t storage_; - bool init_; -}; +using std::nullopt; +using std::nullopt_t; +using std::optional; } // namespace etensor } // namespace runtime