diff --git a/kernels/portable/cpu/util/elementwise_util.cpp b/kernels/portable/cpu/util/elementwise_util.cpp new file mode 100644 index 00000000000..eae6977b1fd --- /dev/null +++ b/kernels/portable/cpu/util/elementwise_util.cpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace torch::executor::native::utils::internal { + +template +inline bool validate_elementwise_fn_inputs_impl( + KernelRuntimeContext& ctx, + const Tensor& out, + SupportedTensorDtypes out_dtypes, + ScalarType compute_type, + Args... inputs) { + static_assert( + (std::is_same_v> && + ...)); + const auto check_input_dtype = [](auto input, auto compute_type) { + return internal::check_tensor_dtype( + *input.first, input.second, compute_type); + }; + ET_KERNEL_CHECK( + ctx, + (check_input_dtype(inputs, compute_type) && ...) && + internal::check_tensor_dtype(out, out_dtypes, compute_type), + InvalidArgument, + false); + + return true; +} + +bool validate_elementwise_fn_inputs( + KernelRuntimeContext& ctx, + const Tensor& out, + SupportedTensorDtypes out_dtypes, + ScalarType compute_type, + std::pair input) { + return validate_elementwise_fn_inputs_impl( + ctx, + out, + out_dtypes, + compute_type, + input); +} + +bool validate_elementwise_fn_inputs( + KernelRuntimeContext& ctx, + const Tensor& out, + SupportedTensorDtypes out_dtypes, + ScalarType compute_type, + std::pair input0, + std::pair input1) { + return validate_elementwise_fn_inputs_impl( + ctx, + out, + out_dtypes, + compute_type, + input0, + input1); +} + +bool validate_elementwise_fn_inputs( + KernelRuntimeContext& ctx, + const Tensor& out, + SupportedTensorDtypes out_dtypes, + ScalarType compute_type, + std::pair input0, + std::pair input1, + std::pair input2) { + return validate_elementwise_fn_inputs_impl( + ctx, + out, + out_dtypes, + compute_type, + input0, + input1, + input2); +} + + +} // namespace torch::executor::native::utils::internal diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index a376a89747b..c61ebc074a8 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -220,30 +220,29 @@ inline void dtype_specialized_elementwise_fn_impl( }); } -template -inline bool validate_elementwise_fn_inputs( - const Op& compute_fun, +bool validate_elementwise_fn_inputs( KernelRuntimeContext& ctx, const Tensor& out, SupportedTensorDtypes out_dtypes, - Args... inputs) { - static_assert( - (std::is_same_v> && - ...)); - constexpr auto compute_type = CppTypeToScalarType::value; - const auto check_input_dtype = [](auto input, auto compute_type) { - return internal::check_tensor_dtype( - *input.first, input.second, compute_type); - }; - ET_KERNEL_CHECK( - ctx, - (check_input_dtype(inputs, compute_type) && ...) && - internal::check_tensor_dtype(out, out_dtypes, compute_type), - InvalidArgument, - false); + ScalarType compute_type, + std::pair input); - return true; -} +bool validate_elementwise_fn_inputs( + KernelRuntimeContext& ctx, + const Tensor& out, + SupportedTensorDtypes out_dtypes, + ScalarType compute_type, + std::pair input0, + std::pair input1); + +bool validate_elementwise_fn_inputs( + KernelRuntimeContext& ctx, + const Tensor& out, + SupportedTensorDtypes out_dtypes, + ScalarType compute_type, + std::pair input0, + std::pair input1, + std::pair input2); template < typename CTYPE_COMPUTE, @@ -314,8 +313,9 @@ inline void apply_elementwise_fn_runtime_out_dtypes( const Tensor& out, SupportedTensorDtypes out_dtypes, Args... inputs) { - const bool inputs_valid = validate_elementwise_fn_inputs( - compute_fun, ctx, out, out_dtypes, inputs...); + constexpr auto compute_type = CppTypeToScalarType::value; + const bool inputs_valid = validate_elementwise_fn_inputs( + ctx, out, out_dtypes, compute_type, inputs...); if (!inputs_valid) { return; } @@ -339,13 +339,13 @@ inline void apply_elementwise_fn( KernelRuntimeContext& ctx, const Tensor& out, Args... inputs) { - const bool inputs_valid = validate_elementwise_fn_inputs( - compute_fun, ctx, out, out_dtypes, inputs...); + constexpr auto compute_type = CppTypeToScalarType::value; + const bool inputs_valid = validate_elementwise_fn_inputs( + ctx, out, out_dtypes, compute_type, inputs...); if (!inputs_valid) { return; } - constexpr auto compute_type = CppTypeToScalarType::value; if constexpr (should_include_kernel_dtype(op_name, compute_type)) { const bool all_inputs_compute_dtype = ((inputs.first->scalar_type() == compute_type) && ...); diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index 44b95aa55c4..4ebb14bbe88 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -104,6 +104,9 @@ def define_common_targets(): runtime.cxx_library( name = "elementwise_util", + srcs = [ + "elementwise_util.cpp", + ], exported_headers = [ "elementwise_util.h", ],