| 
 | 1 | +/*  | 
 | 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates.  | 
 | 3 | + * All rights reserved.  | 
 | 4 | + *  | 
 | 5 | + * This source code is licensed under the BSD-style license found in the  | 
 | 6 | + * LICENSE file in the root directory of this source tree.  | 
 | 7 | + */  | 
 | 8 | + | 
 | 9 | +#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator  | 
 | 10 | +#include <executorch/kernels/test/TestUtil.h>  | 
 | 11 | +#include <executorch/kernels/test/supported_features.h>  | 
 | 12 | +#include <executorch/runtime/core/exec_aten/exec_aten.h>  | 
 | 13 | +#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>  | 
 | 14 | +#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>  | 
 | 15 | + | 
 | 16 | +#include <gtest/gtest.h>  | 
 | 17 | + | 
 | 18 | +using executorch::aten::Scalar;  | 
 | 19 | +using executorch::aten::ScalarType;  | 
 | 20 | +using executorch::aten::string_view;  | 
 | 21 | +using executorch::aten::Tensor;  | 
 | 22 | +using torch::executor::testing::TensorFactory;  | 
 | 23 | + | 
 | 24 | +class OpEluTest : public OperatorTest {  | 
 | 25 | + protected:  | 
 | 26 | +  Tensor& op_elu_out(  | 
 | 27 | +      const Tensor& self,  | 
 | 28 | +      const Scalar& alpha,  | 
 | 29 | +      const Scalar& scale,  | 
 | 30 | +      const Scalar& input_scale,  | 
 | 31 | +      Tensor& out) {  | 
 | 32 | +    return torch::executor::aten::elu_outf(  | 
 | 33 | +        context_, self, alpha, scale, input_scale, out);  | 
 | 34 | +  }  | 
 | 35 | + | 
 | 36 | +  template <ScalarType DTYPE>  | 
 | 37 | +  void test_elu_execution() {  | 
 | 38 | +    TensorFactory<DTYPE> tf;  | 
 | 39 | + | 
 | 40 | +    const std::vector<int32_t> sizes = {3, 2};  | 
 | 41 | + | 
 | 42 | +    Tensor in = tf.make(sizes, /*data=*/{-0.125, -0.25, -1, 0, 1.25, 100});  | 
 | 43 | + | 
 | 44 | +    Tensor out = tf.zeros(sizes);  | 
 | 45 | + | 
 | 46 | +    // Run full gelu.  | 
 | 47 | +    op_elu_out(in, 1.25, 1, 1, out);  | 
 | 48 | + | 
 | 49 | +    // Check that it matches the expected output.  | 
 | 50 | +    EXPECT_TENSOR_CLOSE(  | 
 | 51 | +        out,  | 
 | 52 | +        tf.make(  | 
 | 53 | +            sizes,  | 
 | 54 | +            /*data=*/  | 
 | 55 | +            {-0.146879, -0.276499, -0.790151, 0, 1.25, 100}));  | 
 | 56 | +  }  | 
 | 57 | + | 
 | 58 | +  template <ScalarType DTYPE>  | 
 | 59 | +  void test_integer_elu_dies() {  | 
 | 60 | +    TensorFactory<DTYPE> tf;  | 
 | 61 | + | 
 | 62 | +    Tensor in = tf.ones({1});  | 
 | 63 | +    Tensor out = tf.ones({1});  | 
 | 64 | +    ET_EXPECT_KERNEL_FAILURE(context_, op_elu_out(in, 1, 1, 1, out));  | 
 | 65 | +  }  | 
 | 66 | +};  | 
 | 67 | + | 
 | 68 | +TEST_F(OpEluTest, Basic) {  | 
 | 69 | +#define TEST_ENTRY(ctype, dtype) test_elu_execution<ScalarType::dtype>();  | 
 | 70 | +  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);  | 
 | 71 | +#undef TEST_ENTRY  | 
 | 72 | +}  | 
 | 73 | + | 
 | 74 | +TEST_F(OpEluTest, UnhandledDtypeDies) {  | 
 | 75 | +#define TEST_ENTRY(ctype, dtype) test_integer_elu_dies<ScalarType::dtype>();  | 
 | 76 | +  ET_FORALL_INT_TYPES(TEST_ENTRY);  | 
 | 77 | +#undef TEST_ENTRY  | 
 | 78 | +}  | 
 | 79 | + | 
 | 80 | +TEST_F(OpEluTest, MismatchedOutputDtypeDies) {  | 
 | 81 | +  // Two different dtypes. This test uses two types with the same size to  | 
 | 82 | +  // demonstrate that the ScalarType itself matters, not the size of the  | 
 | 83 | +  // tensor elements.  | 
 | 84 | +  TensorFactory<ScalarType::Float> tf_float;  | 
 | 85 | +  TensorFactory<ScalarType::Double> tf_double;  | 
 | 86 | + | 
 | 87 | +  const std::vector<int32_t> sizes = {2, 2};  | 
 | 88 | + | 
 | 89 | +  Tensor a = tf_float.ones(sizes);  | 
 | 90 | + | 
 | 91 | +  // Destination with a dtype different from the input.  | 
 | 92 | +  Tensor out = tf_double.zeros(sizes);  | 
 | 93 | + | 
 | 94 | +  ET_EXPECT_KERNEL_FAILURE(context_, op_elu_out(a, 1, 1, 1, out));  | 
 | 95 | +}  | 
0 commit comments