99#include < executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
1010#include < executorch/kernels/test/TestUtil.h>
1111#include < executorch/kernels/test/supported_features.h>
12- #include < executorch/runtime/core/exec_aten/exec_aten .h>
12+ #include < executorch/runtime/core/error .h>
1313#include < executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
1414#include < executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
1515#include < executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -23,6 +23,7 @@ using exec_aten::optional;
2323using exec_aten::ScalarType;
2424using exec_aten::Tensor;
2525using torch::executor::testing::TensorFactory;
26+ using executorch::runtime::Error;
2627
2728class OpMeanOutTest : public OperatorTest {
2829 protected:
@@ -36,6 +37,14 @@ class OpMeanOutTest : public OperatorTest {
3637 context_, self, dim, keepdim, dtype, out);
3738 }
3839
40+ Tensor& op_mean_dtype_out (
41+ const Tensor& self,
42+ optional<ScalarType> dtype,
43+ Tensor& out) {
44+ return torch::executor::aten::mean_outf (
45+ context_, self, dtype, out);
46+ }
47+
3948 template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
4049 void test_mean_dim_out_invalid_dimensions () {
4150 TensorFactory<IN_DTYPE> tf_in;
@@ -466,3 +475,80 @@ TEST_F(OpMeanOutTest, DynamicShapeUnbound) {
466475 op_mean_out (x, ArrayRef<int64_t >{1 }, false , ScalarType::Float, out);
467476 EXPECT_TENSOR_CLOSE (out, expected_result);
468477}
478+
479+ TEST_F (OpMeanOutTest, DTypeOutFloatValid) {
480+ TensorFactory<ScalarType::Float> tf;
481+
482+ Tensor x = tf.make (
483+ {10 , 10 },
484+ {1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
485+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
486+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
487+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
488+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
489+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
490+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
491+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 });
492+ Tensor expected_result =
493+ tf.make ({}, {1.0 });
494+
495+ Tensor out = tf.zeros ({});
496+ Tensor ret =
497+ op_mean_dtype_out (x, ScalarType::Float, out);
498+ EXPECT_TENSOR_CLOSE (out, expected_result);
499+ }
500+
501+ TEST_F (OpMeanOutTest, DTypeOutFloatToBoolInvalid) {
502+ TensorFactory<ScalarType::Float> tf;
503+
504+ Tensor x = tf.make (
505+ {10 , 10 },
506+ {1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
507+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
508+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
509+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
510+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
511+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
512+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
513+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 });
514+ Tensor expected_result =
515+ tf.make ({}, {1.0 });
516+
517+ Tensor out = tf.zeros ({});
518+
519+ ET_EXPECT_KERNEL_FAILURE (
520+ context_,
521+ op_mean_dtype_out (x, ScalarType::Bool, out));
522+ }
523+
524+ TEST_F (OpMeanOutTest, DTypeOutFloatInfinity) {
525+ TensorFactory<ScalarType::Float> tf;
526+
527+ Tensor x = tf.make (
528+ {2 , 1 },
529+ {INFINITY, INFINITY});
530+ Tensor expected_result =
531+ tf.make ({}, {INFINITY});
532+
533+ Tensor out = tf.zeros ({});
534+
535+ Tensor ret =
536+ op_mean_dtype_out (x, ScalarType::Float, out);
537+ EXPECT_TENSOR_CLOSE (out, expected_result);
538+ }
539+
540+ TEST_F (OpMeanOutTest, DTypeOutFloatNAN) {
541+ TensorFactory<ScalarType::Float> tf;
542+
543+ Tensor x = tf.make (
544+ {2 , 1 },
545+ {NAN, INFINITY});
546+ Tensor expected_result =
547+ tf.make ({}, {NAN});
548+
549+ Tensor out = tf.zeros ({});
550+
551+ Tensor ret =
552+ op_mean_dtype_out (x, ScalarType::Float, out);
553+ EXPECT_TENSOR_CLOSE (out, expected_result);
554+ }
0 commit comments