From 6ad1c02029e9aeb732fea7bbf67e2a29d43092ce Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 24 Mar 2025 12:48:51 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_randn.cpp | 44 +++++++++++++++++++ kernels/portable/functions.yaml | 6 ++- .../kernels/portable/op_registration_util.bzl | 4 ++ 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 kernels/portable/cpu/op_randn.cpp diff --git a/kernels/portable/cpu/op_randn.cpp b/kernels/portable/cpu/op_randn.cpp new file mode 100644 index 00000000000..4b83cc35185 --- /dev/null +++ b/kernels/portable/cpu/op_randn.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +namespace torch::executor::native { + +using executorch::aten::Tensor; + +Tensor& randn_out( + KernelRuntimeContext& context, + IntArrayRef size, + Tensor& out) { + (void)context; + + // Resize for dynamic shape + ET_KERNEL_CHECK_MSG( + context, + resize_tensor(out, size) == Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor."); + + std::default_random_engine gen; + ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "randn.out", CTYPE, [&]() { + using dist_type = std::conditional_t, float, CTYPE>; + std::normal_distribution dist; + std::generate_n(out.mutable_data_ptr(), out.numel(), [&]() { + return static_cast(dist(gen)); + }); + }); + return out; +} + +} // namespace torch::executor::native diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index 29dfe8b1a0c..9775bb490c7 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -308,7 +308,6 @@ - arg_meta: null kernel_name: torch::executor::div_out_mode - - op: embedding.out kernels: - arg_meta: null @@ -697,6 +696,11 @@ - arg_meta: null kernel_name: torch::executor::prod_out +- op: randn.out + kernels: + - arg_meta: null + kernel_name: torch::executor::randn_out + - op: reciprocal.out kernels: - arg_meta: null diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl index b56413b92f4..2cd0f623cf3 100644 --- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -961,6 +961,10 @@ ATEN_OPS = ( "//executorch/kernels/portable/cpu/util:reduce_util", ], ), + op_target( + name = "op_randn", + deps = [], + ), op_target( name = "op_reciprocal", deps = [