2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/pytorch.txt
@@ -1 +1 @@
295f2ed4d103017f7e19a7b8263ece606cd629db
59d5cf083b4f860dea76fe8936076177f9367f10
9 changes: 9 additions & 0 deletions exir/dialects/edge/op/sample_input.py
@@ -424,6 +424,15 @@
],
"returns": [Return(ArgType.Tensor)],
},
"elu.default": { # (Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!)
"args": [
InArg(ArgType.Tensor),
InArg(ArgType.Scalar),
InArg(ArgType.Scalar),
InArg(ArgType.Scalar),
],
"returns": [Return(ArgType.Tensor)],
},
"embedding.default": { # (Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
"args": [
InArg(ArgType.Tensor),
4 changes: 2 additions & 2 deletions install_requirements.py
@@ -71,7 +71,7 @@ def python_is_compatible():
#
# NOTE: If you're changing, make the corresponding change in .ci/docker/ci_commit_pins/pytorch.txt
# by picking the hash from the same date in https://hud.pytorch.org/hud/pytorch/pytorch/nightly/
NIGHTLY_VERSION = "dev20250311"
NIGHTLY_VERSION = "dev20250325"


def install_requirements(use_pytorch_nightly):
@@ -80,7 +80,7 @@ def install_requirements(use_pytorch_nightly):
# Setting use_pytorch_nightly to false to test the pinned PyTorch commit. Note
# that we don't need to set any version number there because they have already
# been installed on CI before this step, so pip won't reinstall them
f"torch==2.7.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torch",
f"torch==2.8.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torch",
(
f"torchvision==0.22.0.{NIGHTLY_VERSION}"
if use_pytorch_nightly
2 changes: 2 additions & 0 deletions kernels/aten/functions.yaml
@@ -141,6 +141,8 @@

- op: div.out_mode

- op: elu.out

- op: embedding.out

- op: empty.out
62 changes: 62 additions & 0 deletions kernels/portable/cpu/op_elu.cpp
@@ -0,0 +1,62 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cmath>
#include <type_traits>

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch::executor::native {

Tensor& elu_out(
KernelRuntimeContext& ctx,
const Tensor& in,
const Scalar& alpha,
const Scalar& scale,
const Scalar& input_scale,
Tensor& out) {
ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
ET_KERNEL_CHECK(
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

ET_KERNEL_CHECK(ctx, tensor_is_floating_type(in), InvalidArgument, out);

ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);

static constexpr const char op_name[] = "elu.out";
ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
using MathT = std::
conditional_t<c10::is_reduced_floating_point_v<CTYPE>, float, CTYPE>;
MathT math_alpha = 0;
MathT math_scale = 0;
MathT math_input_scale = 0;
ET_EXTRACT_SCALAR(alpha, math_alpha);
ET_EXTRACT_SCALAR(scale, math_scale);
ET_EXTRACT_SCALAR(input_scale, math_input_scale);
const auto negcoef = math_alpha * math_scale;
utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
[negcoef, math_scale, math_input_scale](auto x) {
return MathT(x) <= MathT(0)
? std::expm1(MathT(x) * math_input_scale) * negcoef
: MathT(x) * math_scale;
},
ctx,
in,
utils::SupportedTensorDtypes::FLOATHBF16,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
});
return out;
}

} // namespace torch::executor::native
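
For reference, the lambda above is the standard ELU parameterization: scale * x for x > 0, and alpha * scale * (exp(x * input_scale) - 1) for x <= 0, computed via expm1 for accuracy near zero. The following standalone sketch (not part of the diff; plain standard C++) reproduces the expected values used in op_elu_test.cpp below:

/*
 * Standalone sketch of the ELU math implemented by elu_out above.
 * Not part of the diff; for illustration only.
 */
#include <cmath>
#include <cstdio>

// Same piecewise form as the kernel's lambda:
//   x >  0: scale * x
//   x <= 0: alpha * scale * expm1(x * input_scale)
double elu_ref(double x, double alpha, double scale, double input_scale) {
  const double negcoef = alpha * scale;
  return x <= 0 ? std::expm1(x * input_scale) * negcoef : x * scale;
}

int main() {
  // With the test's arguments alpha=1.25, scale=1, input_scale=1:
  //   elu_ref(-0.125) == 1.25 * expm1(-0.125) ~= -0.146879
  //   elu_ref(-1)     == 1.25 * expm1(-1)     ~= -0.790151
  //   elu_ref(1.25)   == 1.25 (positive inputs pass through)
  std::printf(
      "%f %f %f\n",
      elu_ref(-0.125, 1.25, 1, 1),
      elu_ref(-1, 1.25, 1, 1),
      elu_ref(1.25, 1.25, 1, 1));
  return 0;
}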
5 changes: 5 additions & 0 deletions kernels/portable/functions.yaml
@@ -329,6 +329,11 @@
- arg_meta: null
kernel_name: torch::executor::eq_tensor_out

- op: elu.out
kernels:
- arg_meta: null
kernel_name: torch::executor::elu_out

- op: erf.out
kernels:
- arg_meta: null
1 change: 1 addition & 0 deletions kernels/test/CMakeLists.txt
@@ -135,6 +135,7 @@ set(all_test_sources
"op_detach_copy_test.cpp"
"op_diagonal_copy_test.cpp"
"op_div_test.cpp"
"op_elu_test.cpp"
"op_embedding_test.cpp"
"op_empty_test.cpp"
"op_eq_test.cpp"
95 changes: 95 additions & 0 deletions kernels/test/op_elu_test.cpp
@@ -0,0 +1,95 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/kernels/test/supported_features.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>

using executorch::aten::Scalar;
using executorch::aten::ScalarType;
using executorch::aten::string_view;
using executorch::aten::Tensor;
using torch::executor::testing::TensorFactory;

class OpEluTest : public OperatorTest {
protected:
Tensor& op_elu_out(
const Tensor& self,
const Scalar& alpha,
const Scalar& scale,
const Scalar& input_scale,
Tensor& out) {
return torch::executor::aten::elu_outf(
context_, self, alpha, scale, input_scale, out);
}

template <ScalarType DTYPE>
void test_elu_execution() {
TensorFactory<DTYPE> tf;

const std::vector<int32_t> sizes = {3, 2};

Tensor in = tf.make(sizes, /*data=*/{-0.125, -0.25, -1, 0, 1.25, 100});

Tensor out = tf.zeros(sizes);

// Run elu with alpha = 1.25 and default scale and input_scale.
op_elu_out(in, 1.25, 1, 1, out);

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(
out,
tf.make(
sizes,
/*data=*/
{-0.146879, -0.276499, -0.790151, 0, 1.25, 100}));
}

template <ScalarType DTYPE>
void test_integer_elu_dies() {
TensorFactory<DTYPE> tf;

Tensor in = tf.ones({1});
Tensor out = tf.ones({1});
ET_EXPECT_KERNEL_FAILURE(context_, op_elu_out(in, 1, 1, 1, out));
}
};

TEST_F(OpEluTest, Basic) {
#define TEST_ENTRY(ctype, dtype) test_elu_execution<ScalarType::dtype>();
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

TEST_F(OpEluTest, UnhandledDtypeDies) {
#define TEST_ENTRY(ctype, dtype) test_integer_elu_dies<ScalarType::dtype>();
ET_FORALL_INT_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

TEST_F(OpEluTest, MismatchedOutputDtypeDies) {
// The output must have the same dtype as the input. This test uses a
// Double destination for a Float input to verify that the kernel
// rejects a mismatched output dtype.
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Double> tf_double;

const std::vector<int32_t> sizes = {2, 2};

Tensor a = tf_float.ones(sizes);

// Destination with a dtype different from the input.
Tensor out = tf_double.zeros(sizes);

ET_EXPECT_KERNEL_FAILURE(context_, op_elu_out(a, 1, 1, 1, out));
}
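
As a usage sketch, here is one way to call the new kernel directly with PyTorch's default ELU parameters (alpha=1, scale=1, input_scale=1). This is an illustration under assumptions, not part of the diff: it presumes the portable kernel and the testing_util TensorFactory helper are linked in, and it forward-declares elu_out to mirror op_elu.cpp rather than relying on a generated functions header:

// Hypothetical call-site sketch; not part of the diff.
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch::executor::native {
// Declaration mirrors op_elu.cpp above.
Tensor& elu_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale,
    Tensor& out);
} // namespace torch::executor::native

void elu_example() {
  torch::executor::testing::TensorFactory<executorch::aten::ScalarType::Float>
      tf;
  auto in = tf.make({2}, {-1.0, 2.0});
  auto out = tf.zeros({2});
  torch::executor::KernelRuntimeContext ctx;
  // PyTorch's default ELU parameters: alpha=1, scale=1, input_scale=1.
  // Expected result: {expm1(-1) ~= -0.632121, 2.0}.
  torch::executor::native::elu_out(ctx, in, 1, 1, 1, out);
}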
1 change: 1 addition & 0 deletions kernels/test/targets.bzl
@@ -215,6 +215,7 @@ def define_common_targets():
_common_op_test("op_detach_copy_test", ["aten", "portable"])
_common_op_test("op_diagonal_copy_test", ["aten", "portable"])
_common_op_test("op_div_test", ["aten", "portable", "optimized"])
_common_op_test("op_elu_test", ["aten", "portable"])
_common_op_test("op_embedding_test", ["aten", "portable"])
_common_op_test("op_empty_test", ["aten", "portable"])
_common_op_test("op_eq_test", ["aten", "portable"])
7 changes: 7 additions & 0 deletions kernels/portable/cpu/targets.bzl
@@ -482,6 +482,13 @@ ATEN_OPS = (
":scalar_utils",
],
),
op_target(
name = "op_elu",
deps = [
":scalar_utils",
"//executorch/kernels/portable/cpu/util:elementwise_util",
],
),
op_target(
name = "op_embedding",
deps = [