1010#include < executorch/kernels/test/TestUtil.h>
1111#include < executorch/kernels/test/supported_features.h>
1212#include < executorch/runtime/core/exec_aten/exec_aten.h>
13+ #include < executorch/runtime/core/error.h>
1314#include < executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
1415#include < executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
1516#include < executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -23,6 +24,7 @@ using exec_aten::optional;
2324using exec_aten::ScalarType;
2425using exec_aten::Tensor;
2526using torch::executor::testing::TensorFactory;
27+ using executorch::runtime::Error;
2628
2729class OpMeanOutTest : public OperatorTest {
2830 protected:
@@ -36,6 +38,14 @@ class OpMeanOutTest : public OperatorTest {
3638 context_, self, dim, keepdim, dtype, out);
3739 }
3840
41+ Tensor& op_mean_dtype_out (
42+ const Tensor& self,
43+ optional<ScalarType> dtype,
44+ Tensor& out) {
45+ return torch::executor::aten::mean_outf (
46+ context_, self, dtype, out);
47+ }
48+
3949 template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
4050 void test_mean_dim_out_invalid_dimensions () {
4151 TensorFactory<IN_DTYPE> tf_in;
@@ -466,3 +476,80 @@ TEST_F(OpMeanOutTest, DynamicShapeUnbound) {
466476 op_mean_out (x, ArrayRef<int64_t >{1 }, false , ScalarType::Float, out);
467477 EXPECT_TENSOR_CLOSE (out, expected_result);
468478}
479+
480+ TEST_F (OpMeanOutTest, DTypeOutFloatValid) {
481+ TensorFactory<ScalarType::Float> tf;
482+
483+ Tensor x = tf.make (
484+ {10 , 10 },
485+ {1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
486+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
487+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
488+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
489+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
490+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
491+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
492+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 });
493+ Tensor expected_result =
494+ tf.make ({}, {1.0 });
495+
496+ Tensor out = tf.zeros ({});
497+ Tensor ret =
498+ op_mean_dtype_out (x, ScalarType::Float, out);
499+ EXPECT_TENSOR_CLOSE (out, expected_result);
500+ }
501+
502+ TEST_F (OpMeanOutTest, DTypeOutFloatToBoolInvalid) {
503+ TensorFactory<ScalarType::Float> tf;
504+
505+ Tensor x = tf.make (
506+ {10 , 10 },
507+ {1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
508+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
509+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
510+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
511+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
512+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
513+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
514+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 });
515+ Tensor expected_result =
516+ tf.make ({}, {1.0 });
517+
518+ Tensor out = tf.zeros ({});
519+
520+ ET_EXPECT_KERNEL_FAILURE (
521+ context_,
522+ op_mean_dtype_out (x, ScalarType::Bool, out));
523+ }
524+
525+ TEST_F (OpMeanOutTest, DTypeOutFloatInfinity) {
526+ TensorFactory<ScalarType::Float> tf;
527+
528+ Tensor x = tf.make (
529+ {2 , 1 },
530+ {INFINITY, INFINITY});
531+ Tensor expected_result =
532+ tf.make ({}, {INFINITY});
533+
534+ Tensor out = tf.zeros ({});
535+
536+ Tensor ret =
537+ op_mean_dtype_out (x, ScalarType::Float, out);
538+ EXPECT_TENSOR_CLOSE (out, expected_result);
539+ }
540+
541+ TEST_F (OpMeanOutTest, DTypeOutFloatNAN) {
542+ TensorFactory<ScalarType::Float> tf;
543+
544+ Tensor x = tf.make (
545+ {2 , 1 },
546+ {NAN, INFINITY});
547+ Tensor expected_result =
548+ tf.make ({}, {NAN});
549+
550+ Tensor out = tf.zeros ({});
551+
552+ Tensor ret =
553+ op_mean_dtype_out (x, ScalarType::Float, out);
554+ EXPECT_TENSOR_CLOSE (out, expected_result);
555+ }
0 commit comments