diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index 98df54fe75f..77bf9cd573b 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -317,6 +317,8 @@ - op: rand.out +- op: randn.out + - op: reciprocal.out - op: relu.out diff --git a/kernels/portable/cpu/op_randn.cpp b/kernels/portable/cpu/op_randn.cpp new file mode 100644 index 00000000000..a0732e7f177 --- /dev/null +++ b/kernels/portable/cpu/op_randn.cpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include + +#include +#include + +#include + +namespace torch { +namespace executor { +namespace native { + +using executorch::aten::IntArrayRef; +using Tensor = executorch::aten::Tensor; +using ScalarType = executorch::aten::ScalarType; + +Tensor& +randn_out(KernelRuntimeContext& ctx, const IntArrayRef sizes, Tensor& out) { + (void)ctx; + + std::mt19937 gen((std::random_device())()); + std::normal_distribution dist(0.0, 1.0); + + // Resize for dynamic shape + ET_KERNEL_CHECK_MSG( + ctx, + resize_tensor(out, sizes) == Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor."); + + ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "randn.out", CTYPE, [&] { + auto data_out = out.mutable_data_ptr(); + for (const auto i : c10::irange(out.numel())) { + data_out[i] = static_cast(dist(gen)); + } + }); + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index b763b7ae585..feaee415f91 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -719,6 +719,12 @@ kernel_name: torch::executor::rand_out tags: nondeterministic_seeded +- op: randn.out + kernels: + - arg_meta: null + kernel_name: torch::executor::randn_out + tags: nondeterministic_seeded + - op: reciprocal.out kernels: - arg_meta: null diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index 64a15e5c385..4f174b5a652 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -198,6 +198,7 @@ set(all_test_sources "op_pixel_shuffle_test.cpp" "op_prod_test.cpp" "op_rand_test.cpp" + "op_randn_test.cpp" "op_reciprocal_test.cpp" "op_relu_test.cpp" "op_remainder_test.cpp" diff --git a/kernels/test/op_randn_test.cpp b/kernels/test/op_randn_test.cpp new file mode 100644 index 00000000000..41456584e91 --- /dev/null +++ b/kernels/test/op_randn_test.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include // Declares the operator +#include +#include +#include +#include +#include + +#include + +#include +#include + +using executorch::aten::IntArrayRef; +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using torch::executor::testing::TensorFactory; + +class OpRandnTest : public OperatorTest { + protected: + void op_randn_out(const IntArrayRef sizes, Tensor& out) { + torch::executor::aten::randn_outf(context_, sizes, out); + } + + template + void test_randn(std::vector& sizes) { + TensorFactory tf; + + // Tensor factory wants int32 scales, op kernel wants int64. + std::vector sizes_i32; + std::transform( + sizes.begin(), + sizes.end(), + std::back_inserter(sizes_i32), + [](int64_t s) { return static_cast(s); }); + Tensor out = tf.zeros(sizes_i32); + + IntArrayRef sizes_ref(sizes.data(), sizes.size()); + op_randn_out(sizes_ref, out); + + // Check mean and standard deviation. To avoid flaky CI, test pretty + // loosely. + auto out_data = out.const_data_ptr(); + double mean = + std::accumulate( + out_data, + out_data + out.numel(), + 0.0, + [](double acc, CTYPE n) { return acc + static_cast(n); }) / + out.numel(); + double var = std::accumulate( + out_data, + out_data + out.numel(), + 0.0, + [=](double acc, CTYPE n) { + return acc + std::pow(static_cast(n) - mean, 2); + }) / + out.numel(); + auto stdev = std::sqrt(var); + + // These are very rough thresholds. A better test implementation would + // probably do a proper statistical test to compare the generated empirical + // data to the reference distribution, but this should do. + EXPECT_LE(std::abs(mean), 5.0 / std::sqrt(out.numel())); + EXPECT_LE(std::abs(stdev - 1.0), 0.1); + EXPECT_GT(stdev, 0); + } +}; + +TEST_F(OpRandnTest, SmokeTest) { + std::vector sizes = {2, 3, 4, 128}; + +#define TEST_ENTRY(ctype, dtype) test_randn(sizes); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} + +TEST_F(OpRandnTest, Rank) { + std::vector sizes = {1024}; + + for (int64_t i = 0; i < 4; i++) { + sizes.push_back(i + 1); + test_randn(sizes); + } +} diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index 4d2aa7f6644..bde3b8632b0 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -286,6 +286,7 @@ def define_common_targets(): _common_op_test("op_pow_test", ["aten", "portable"]) _common_op_test("op_prod_test", ["aten", "portable"]) _common_op_test("op_rand_test", ["aten", "portable"]) + _common_op_test("op_randn_test", ["aten", "portable"]) _common_op_test("op_reciprocal_test", ["aten", "portable"]) _common_op_test("op_relu_test", ["aten", "portable"]) _common_op_test("op_remainder_test", ["aten", "portable"]) diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl index f98196daa35..a731ce5c674 100644 --- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -981,6 +981,14 @@ ATEN_OPS = ( "//executorch/runtime/core/exec_aten/util:tensor_util", ] ), + op_target( + name = "op_randn", + deps = [ + ":scalar_utils", + "//executorch/runtime/core/exec_aten/util:scalar_type_util", + "//executorch/runtime/core/exec_aten/util:tensor_util", + ] + ), op_target( name = "op_reciprocal", deps = [