| 
 | 1 | +/*  | 
 | 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates.  | 
 | 3 | + * All rights reserved.  | 
 | 4 | + *  | 
 | 5 | + * This source code is licensed under the BSD-style license found in the  | 
 | 6 | + * LICENSE file in the root directory of this source tree.  | 
 | 7 | + */  | 
 | 8 | + | 
 | 9 | +#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator  | 
 | 10 | +#include <executorch/kernels/test/TestUtil.h>  | 
 | 11 | +#include <executorch/runtime/core/exec_aten/exec_aten.h>  | 
 | 12 | +#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>  | 
 | 13 | +#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>  | 
 | 14 | + | 
 | 15 | +#include <gtest/gtest.h>  | 
 | 16 | + | 
 | 17 | +using namespace ::testing;  | 
 | 18 | +using executorch::aten::ScalarType;  | 
 | 19 | +using executorch::aten::Tensor;  | 
 | 20 | +using torch::executor::testing::TensorFactory;  | 
 | 21 | + | 
 | 22 | +class OpViewAsRealTest : public OperatorTest {  | 
 | 23 | + protected:  | 
 | 24 | +  Tensor& view_as_real_copy_out(const Tensor& self, Tensor& out) {  | 
 | 25 | +    return torch::executor::aten::view_as_real_copy_outf(context_, self, out);  | 
 | 26 | +  }  | 
 | 27 | + | 
 | 28 | +  template <typename CTYPE, ScalarType DTYPE>  | 
 | 29 | +  void run_complex_smoke_test() {  | 
 | 30 | +    TensorFactory<DTYPE> tf;  | 
 | 31 | +    constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);  | 
 | 32 | +    TensorFactory<REAL_DTYPE> tf_out;  | 
 | 33 | + | 
 | 34 | +    Tensor in = tf.make(  | 
 | 35 | +        {2, 2},  | 
 | 36 | +        {CTYPE(3, 4), CTYPE(-1.7, 7.4), CTYPE(5, -12), CTYPE(8.3, 0.1)});  | 
 | 37 | +    Tensor out = tf_out.zeros({2, 2, 2});  | 
 | 38 | +    Tensor expected =  | 
 | 39 | +        tf_out.make({2, 2, 2}, {3, 4, -1.7, 7.4, 5, -12, 8.3, 0.1});  | 
 | 40 | +    Tensor ret = view_as_real_copy_out(in, out);  | 
 | 41 | + | 
 | 42 | +    EXPECT_TENSOR_EQ(out, ret);  | 
 | 43 | +    EXPECT_TENSOR_EQ(out, expected);  | 
 | 44 | +  }  | 
 | 45 | + | 
 | 46 | +  // Tests on tensors with 0 size  | 
 | 47 | +  template <typename CTYPE, ScalarType DTYPE>  | 
 | 48 | +  void test_empty_input() {  | 
 | 49 | +    TensorFactory<DTYPE> tf;  | 
 | 50 | +    constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);  | 
 | 51 | +    TensorFactory<REAL_DTYPE> tf_out;  | 
 | 52 | + | 
 | 53 | +    Tensor in = tf.make(/*sizes=*/{3, 0, 4}, /*data=*/{});  | 
 | 54 | +    Tensor out = tf_out.zeros({3, 0, 4, 2});  | 
 | 55 | +    Tensor expected = tf_out.make(/*sizes=*/{3, 0, 4, 2}, /*data=*/{});  | 
 | 56 | +    Tensor ret = view_as_real_copy_out(in, out);  | 
 | 57 | + | 
 | 58 | +    EXPECT_TENSOR_EQ(out, ret);  | 
 | 59 | +    EXPECT_TENSOR_EQ(out, expected);  | 
 | 60 | +  }  | 
 | 61 | + | 
 | 62 | +  // Tests on 0-dim input tensors  | 
 | 63 | +  template <typename CTYPE, ScalarType DTYPE>  | 
 | 64 | +  void zero_dim_input() {  | 
 | 65 | +    TensorFactory<DTYPE> tf;  | 
 | 66 | +    constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);  | 
 | 67 | +    TensorFactory<REAL_DTYPE> tf_out;  | 
 | 68 | + | 
 | 69 | +    Tensor in = tf.make(/*sizes=*/{}, {CTYPE(0, 0)});  | 
 | 70 | +    Tensor out = tf_out.zeros({2});  | 
 | 71 | +    Tensor expected = tf_out.zeros(/*sizes=*/{2});  | 
 | 72 | +    Tensor ret = view_as_real_copy_out(in, out);  | 
 | 73 | + | 
 | 74 | +    EXPECT_TENSOR_EQ(out, ret);  | 
 | 75 | +    EXPECT_TENSOR_EQ(out, expected);  | 
 | 76 | +  }  | 
 | 77 | +};  | 
 | 78 | + | 
 | 79 | +TEST_F(OpViewAsRealTest, ComplexSmokeTest) {  | 
 | 80 | +#define RUN_SMOKE_TEST(ctype, dtype)                  \  | 
 | 81 | +  run_complex_smoke_test<ctype, ScalarType::dtype>(); \  | 
 | 82 | +  test_empty_input<ctype, ScalarType::dtype>();       \  | 
 | 83 | +  zero_dim_input<ctype, ScalarType::dtype>();  | 
 | 84 | +  ET_FORALL_COMPLEXH_TYPES(RUN_SMOKE_TEST);  | 
 | 85 | +#undef RUN_SMOKE_TEST  | 
 | 86 | +}  | 
0 commit comments