Skip to content

Commit f378182

Browse files
committed
STASH: randn_out
Just the operator implementation for randn_out in case we need it in the future. fails torchgen assertion during building currently and needs a cursory test. ghstack-source-id: 2805458 ghstack-comment-id: 2749232500 Pull Request resolved: #9553
1 parent 94ec549 commit f378182

File tree

3 files changed

+53
-1
lines changed

3 files changed

+53
-1
lines changed

kernels/portable/cpu/op_randn.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/kernel/kernel_includes.h>
10+
#include <executorch/runtime/platform/assert.h>
11+
12+
#include <random>
13+
#include <type_traits>
14+
15+
namespace torch::executor::native {
16+
17+
using executorch::aten::Tensor;
18+
19+
Tensor& randn_out(
20+
KernelRuntimeContext& context,
21+
IntArrayRef size,
22+
Tensor& out) {
23+
(void)context;
24+
25+
// Resize for dynamic shape
26+
ET_KERNEL_CHECK_MSG(
27+
context,
28+
resize_tensor(out, size) == Error::Ok,
29+
InvalidArgument,
30+
out,
31+
"Failed to resize output tensor.");
32+
33+
std::default_random_engine gen;
34+
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "randn.out", CTYPE, [&]() {
35+
using dist_type = std::conditional_t<c10::is_reduced_floating_point_v<CTYPE>, float, CTYPE>;
36+
std::normal_distribution<dist_type> dist;
37+
std::generate_n(out.mutable_data_ptr<CTYPE>(), out.numel(), [&]() {
38+
return static_cast<CTYPE>(dist(gen));
39+
});
40+
});
41+
return out;
42+
}
43+
44+
} // namespace torch::executor::native

kernels/portable/functions.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,6 @@
308308
- arg_meta: null
309309
kernel_name: torch::executor::div_out_mode
310310

311-
312311
- op: embedding.out
313312
kernels:
314313
- arg_meta: null
@@ -697,6 +696,11 @@
697696
- arg_meta: null
698697
kernel_name: torch::executor::prod_out
699698

699+
- op: randn.out
700+
kernels:
701+
- arg_meta: null
702+
kernel_name: torch::executor::randn_out
703+
700704
- op: reciprocal.out
701705
kernels:
702706
- arg_meta: null

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,10 @@ ATEN_OPS = (
961961
"//executorch/kernels/portable/cpu/util:reduce_util",
962962
],
963963
),
964+
op_target(
965+
name = "op_randn",
966+
deps = [],
967+
),
964968
op_target(
965969
name = "op_reciprocal",
966970
deps = [

0 commit comments

Comments
 (0)