99#include < executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
1010#include < executorch/kernels/test/TestUtil.h>
1111#include < executorch/kernels/test/supported_features.h>
12- #include < executorch/runtime/core/exec_aten/exec_aten .h>
12+ #include < executorch/runtime/core/error .h>
1313#include < executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
1414#include < executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
1515#include < executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -22,6 +22,7 @@ using exec_aten::ArrayRef;
2222using exec_aten::optional;
2323using exec_aten::ScalarType;
2424using exec_aten::Tensor;
25+ using executorch::runtime::Error;
2526using torch::executor::testing::TensorFactory;
2627
2728class OpMeanOutTest : public OperatorTest {
@@ -36,6 +37,13 @@ class OpMeanOutTest : public OperatorTest {
3637 context_, self, dim, keepdim, dtype, out);
3738 }
3839
40+ Tensor& op_mean_dtype_out (
41+ const Tensor& self,
42+ optional<ScalarType> dtype,
43+ Tensor& out) {
44+ return torch::executor::aten::mean_outf (context_, self, dtype, out);
45+ }
46+
3947 template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
4048 void test_mean_dim_out_invalid_dimensions () {
4149 TensorFactory<IN_DTYPE> tf_in;
@@ -466,3 +474,68 @@ TEST_F(OpMeanOutTest, DynamicShapeUnbound) {
466474 op_mean_out (x, ArrayRef<int64_t >{1 }, false , ScalarType::Float, out);
467475 EXPECT_TENSOR_CLOSE (out, expected_result);
468476}
477+
478+ TEST_F (OpMeanOutTest, DTypeOutFloatValid) {
479+ TensorFactory<ScalarType::Float> tf;
480+
481+ Tensor x = tf.make (
482+ {10 , 10 },
483+ {1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
484+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
485+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
486+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
487+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
488+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
489+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
490+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 });
491+ Tensor expected_result = tf.make ({}, {1.0 });
492+
493+ Tensor out = tf.zeros ({});
494+ Tensor ret = op_mean_dtype_out (x, ScalarType::Float, out);
495+ EXPECT_TENSOR_CLOSE (out, expected_result);
496+ }
497+
498+ TEST_F (OpMeanOutTest, DTypeOutFloatToBoolInvalid) {
499+ TensorFactory<ScalarType::Float> tf;
500+
501+ Tensor x = tf.make (
502+ {10 , 10 },
503+ {1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
504+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
505+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
506+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
507+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
508+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
509+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ,
510+ 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 });
511+ Tensor expected_result = tf.make ({}, {1.0 });
512+
513+ Tensor out = tf.zeros ({});
514+
515+ ET_EXPECT_KERNEL_FAILURE (
516+ context_, op_mean_dtype_out (x, ScalarType::Bool, out));
517+ }
518+
519+ TEST_F (OpMeanOutTest, DTypeOutFloatInfinity) {
520+ TensorFactory<ScalarType::Float> tf;
521+
522+ Tensor x = tf.make ({2 , 1 }, {INFINITY, INFINITY});
523+ Tensor expected_result = tf.make ({}, {INFINITY});
524+
525+ Tensor out = tf.zeros ({});
526+
527+ Tensor ret = op_mean_dtype_out (x, ScalarType::Float, out);
528+ EXPECT_TENSOR_CLOSE (out, expected_result);
529+ }
530+
531+ TEST_F (OpMeanOutTest, DTypeOutFloatNAN) {
532+ TensorFactory<ScalarType::Float> tf;
533+
534+ Tensor x = tf.make ({2 , 1 }, {NAN, INFINITY});
535+ Tensor expected_result = tf.make ({}, {NAN});
536+
537+ Tensor out = tf.zeros ({});
538+
539+ Tensor ret = op_mean_dtype_out (x, ScalarType::Float, out);
540+ EXPECT_TENSOR_CLOSE (out, expected_result);
541+ }
0 commit comments