Skip to content

Commit 90f0e53

Browse files
David Linfacebook-github-bot
authored andcommitted
Add mean.dtype_out op for Ads model (pytorch#7404)
Summary: title Reviewed By: manuelcandales Differential Revision: D67453766
1 parent 4192fec commit 90f0e53

File tree

5 files changed

+103
-1
lines changed

5 files changed

+103
-1
lines changed

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@
257257

258258
- op: mean.out
259259

260+
- op: mean.dtype_out
261+
260262
- op: min.dim_min
261263

262264
- op: min.unary_out

kernels/portable/cpu/op_mean.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ Tensor& mean_dim_out(
6666
return out;
6767
}
6868

69+
Tensor& mean_dtype_out(
70+
KernelRuntimeContext& ctx,
71+
const Tensor& in,
72+
optional<ScalarType> dtype,
73+
Tensor& out) {
74+
return mean_dim_out(ctx, in, ArrayRef<int64_t>(), false, dtype, out);
75+
}
76+
6977
} // namespace native
7078
} // namespace executor
7179
} // namespace torch

kernels/portable/cpu/util/reduce_util.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ bool check_mean_dim_args(
386386
check_reduction_args(in, dim_list, keepdim, dtype, out));
387387

388388
if (dtype) {
389+
ET_LOG(Info, "dtype is %hhd", static_cast<int8_t>(dtype.value()));
389390
ET_LOG_AND_RETURN_IF_FALSE(torch::executor::isFloatingType(dtype.value()));
390391
ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == dtype.value());
391392
} else {

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,11 @@
577577
- arg_meta: null
578578
kernel_name: torch::executor::mean_dim_out
579579

580+
- op: mean.dtype_out
581+
kernels:
582+
- arg_meta: null
583+
kernel_name: torch::executor::mean_dtype_out
584+
580585
- op: min.dim_min
581586
kernels:
582587
- arg_meta: null

kernels/test/op_mean_test.cpp

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
1010
#include <executorch/kernels/test/TestUtil.h>
1111
#include <executorch/kernels/test/supported_features.h>
12-
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/error.h>
1313
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
1414
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
1515
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -23,6 +23,7 @@ using exec_aten::optional;
2323
using exec_aten::ScalarType;
2424
using exec_aten::Tensor;
2525
using torch::executor::testing::TensorFactory;
26+
using executorch::runtime::Error;
2627

2728
class OpMeanOutTest : public OperatorTest {
2829
protected:
@@ -36,6 +37,14 @@ class OpMeanOutTest : public OperatorTest {
3637
context_, self, dim, keepdim, dtype, out);
3738
}
3839

40+
Tensor& op_mean_dtype_out(
41+
const Tensor& self,
42+
optional<ScalarType> dtype,
43+
Tensor& out) {
44+
return torch::executor::aten::mean_outf(
45+
context_, self, dtype, out);
46+
}
47+
3948
template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
4049
void test_mean_dim_out_invalid_dimensions() {
4150
TensorFactory<IN_DTYPE> tf_in;
@@ -466,3 +475,80 @@ TEST_F(OpMeanOutTest, DynamicShapeUnbound) {
466475
op_mean_out(x, ArrayRef<int64_t>{1}, false, ScalarType::Float, out);
467476
EXPECT_TENSOR_CLOSE(out, expected_result);
468477
}
478+
479+
TEST_F(OpMeanOutTest, DTypeOutFloatValid) {
480+
TensorFactory<ScalarType::Float> tf;
481+
482+
Tensor x = tf.make(
483+
{10, 10},
484+
{1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
485+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
486+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
487+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
488+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
489+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
490+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
491+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
492+
Tensor expected_result =
493+
tf.make({}, {1.0});
494+
495+
Tensor out = tf.zeros({});
496+
Tensor ret =
497+
op_mean_dtype_out(x, ScalarType::Float, out);
498+
EXPECT_TENSOR_CLOSE(out, expected_result);
499+
}
500+
501+
TEST_F(OpMeanOutTest, DTypeOutFloatToBoolInvalid) {
502+
TensorFactory<ScalarType::Float> tf;
503+
504+
Tensor x = tf.make(
505+
{10, 10},
506+
{1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
507+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
508+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
509+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
510+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
511+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
512+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
513+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
514+
Tensor expected_result =
515+
tf.make({}, {1.0});
516+
517+
Tensor out = tf.zeros({});
518+
519+
ET_EXPECT_KERNEL_FAILURE(
520+
context_,
521+
op_mean_dtype_out(x, ScalarType::Bool, out));
522+
}
523+
524+
TEST_F(OpMeanOutTest, DTypeOutFloatInfinity) {
525+
TensorFactory<ScalarType::Float> tf;
526+
527+
Tensor x = tf.make(
528+
{2, 1},
529+
{INFINITY, INFINITY});
530+
Tensor expected_result =
531+
tf.make({}, {INFINITY});
532+
533+
Tensor out = tf.zeros({});
534+
535+
Tensor ret =
536+
op_mean_dtype_out(x, ScalarType::Float, out);
537+
EXPECT_TENSOR_CLOSE(out, expected_result);
538+
}
539+
540+
TEST_F(OpMeanOutTest, DTypeOutFloatNAN) {
541+
TensorFactory<ScalarType::Float> tf;
542+
543+
Tensor x = tf.make(
544+
{2, 1},
545+
{NAN, INFINITY});
546+
Tensor expected_result =
547+
tf.make({}, {NAN});
548+
549+
Tensor out = tf.zeros({});
550+
551+
Tensor ret =
552+
op_mean_dtype_out(x, ScalarType::Float, out);
553+
EXPECT_TENSOR_CLOSE(out, expected_result);
554+
}

0 commit comments

Comments
 (0)