Skip to content

Commit 478494b

Browse files
committed
Add portable randn kernel implementation
1 parent f2fb351 commit 478494b

File tree

6 files changed

+160
-0
lines changed

6 files changed

+160
-0
lines changed

kernels/portable/cpu/op_randn.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#include <c10/util/irange.h>
9+
10+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
13+
#include <random>
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace native {
18+
19+
using executorch::aten::IntArrayRef;
20+
using Tensor = executorch::aten::Tensor;
21+
using ScalarType = executorch::aten::ScalarType;
22+
23+
template <class CTYPE>
24+
void impl(CTYPE* data, int64_t numel, std::mt19937& gen, std::normal_distribution<double>& dist) {
25+
for (const auto i : c10::irange(numel)) {
26+
auto val = dist(gen);
27+
data[i] = static_cast<CTYPE>(val);
28+
}
29+
}
30+
31+
/**
 * randn.out kernel: fills `out` with samples from a normal distribution with
 * mean 0.0 and standard deviation 1.0.
 *
 * `out` is first resized to `sizes`. If the resize does not return Error::Ok,
 * the check below reports InvalidArgument through `ctx` with `out` as the
 * return value, and no samples are generated.
 *
 * The generator is seeded from std::random_device on every call, so the
 * output is intentionally not reproducible from one call to the next.
 *
 * @param ctx   Kernel runtime context, used for error reporting and dtype
 *              dispatch.
 * @param sizes Requested shape of the output tensor.
 * @param out   Output tensor; its dtype must be one of those handled by
 *              ET_SWITCH_FLOATHBF16_TYPES.
 * @return `out`.
 */
Tensor& randn_out(
    KernelRuntimeContext& ctx,
    const IntArrayRef sizes,
    Tensor& out) {
  // Resize for dynamic shape
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(out, sizes) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

  // Seed only after validation so a failed resize does not touch the
  // system entropy source.
  std::mt19937 gen((std::random_device())());
  std::normal_distribution<double> dist(0.0, 1.0);

  ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "randn.out", CTYPE, [&] {
    auto data_out = out.mutable_data_ptr<CTYPE>();
    impl(data_out, out.numel(), gen, dist);
  });

  return out;
}
59+
60+
} // namespace native
61+
} // namespace executor
62+
} // namespace torch
63+

kernels/portable/functions.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,12 @@
713713
- arg_meta: null
714714
kernel_name: torch::executor::prod_out
715715

716+
- op: randn.out
717+
kernels:
718+
- arg_meta: null
719+
kernel_name: torch::executor::randn_out
720+
tags: nondeterministic_seeded
721+
716722
- op: reciprocal.out
717723
kernels:
718724
- arg_meta: null

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ set(all_test_sources
197197
"op_permute_copy_test.cpp"
198198
"op_pixel_shuffle_test.cpp"
199199
"op_prod_test.cpp"
200+
"op_randn_test.cpp"
200201
"op_reciprocal_test.cpp"
201202
"op_relu_test.cpp"
202203
"op_remainder_test.cpp"

kernels/test/op_randn_test.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <c10/util/irange.h>
10+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
11+
#include <executorch/kernels/test/TestUtil.h>
12+
#include <executorch/kernels/test/supported_features.h>
13+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
14+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
15+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
16+
17+
#include <gtest/gtest.h>
18+
19+
#include <cmath>
20+
#include <numeric>
21+
22+
using executorch::aten::IntArrayRef;
23+
using executorch::aten::ScalarType;
24+
using executorch::aten::Tensor;
25+
using torch::executor::testing::TensorFactory;
26+
27+
class OpRandnTest : public OperatorTest {
28+
protected:
29+
void op_randn_out(
30+
const IntArrayRef sizes,
31+
Tensor& out) {
32+
torch::executor::aten::randn_outf(
33+
context_, sizes, out);
34+
}
35+
36+
template <typename CTYPE, ScalarType DTYPE>
37+
void test_randn(std::vector<int64_t>& sizes) {
38+
TensorFactory<DTYPE> tf;
39+
40+
// Tensor factory wants int32 scales, op kernel wants int64.
41+
std::vector<int32_t> sizes_i32;
42+
std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_i32),
43+
[] (int64_t s) { return static_cast<int32_t>(s); });
44+
Tensor out = tf.zeros(sizes_i32);
45+
46+
IntArrayRef sizes_ref(sizes.data(), sizes.size());
47+
op_randn_out(sizes_ref, out);
48+
49+
// Check mean and standard deviation. To avoid flaky CI, test pretty loosely.
50+
auto out_data = out.const_data_ptr<CTYPE>();
51+
double mean = std::accumulate(out_data, out_data + out.numel(), 0.0, [](double acc, CTYPE n) { return acc + static_cast<double>(n); }) / out.numel();
52+
double var = std::accumulate(out_data, out_data + out.numel(), 0.0,
53+
[=](double acc, CTYPE n) { return acc + std::pow(static_cast<double>(n) - mean, 2); }) / out.numel();
54+
auto stdev = std::sqrt(var);
55+
56+
// These are very rough thresholds. A better test implementation would probably do a proper
57+
// statistical test to compare the generated empirical data to the reference distribution, but
58+
// this should do for now.
59+
EXPECT_LE(std::abs(mean), 5.0 / std::sqrt(out.numel()));
60+
EXPECT_LE(std::abs(stdev - 1.0), 0.1);
61+
EXPECT_GT(stdev, 0);
62+
}
63+
};
64+
65+
// Runs the statistical check on one fixed 4-D shape for every dtype that
// ET_FORALL_FLOATHBF16_TYPES enumerates.
TEST_F(OpRandnTest, SmokeTest) {
  std::vector<int64_t> shape = {2, 3, 4, 128};

#define RUN_RANDN_CASE(ctype, dtype) \
  test_randn<ctype, ScalarType::dtype>(shape);
  ET_FORALL_FLOATHBF16_TYPES(RUN_RANDN_CASE);
#undef RUN_RANDN_CASE
}
73+
74+
// Exercises float outputs of increasing rank. The shape starts as {1024} and
// gains one trailing dimension per iteration: {1024, 1}, {1024, 1, 2}, and so
// on up to {1024, 1, 2, 3, 4}.
TEST_F(OpRandnTest, Rank) {
  std::vector<int64_t> shape = {1024};

  for (int64_t trailing_dim = 1; trailing_dim <= 4; ++trailing_dim) {
    shape.push_back(trailing_dim);
    test_randn<float, ScalarType::Float>(shape);
  }
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def define_common_targets():
285285
_common_op_test("op_pixel_unshuffle_test", ["aten", "portable"])
286286
_common_op_test("op_pow_test", ["aten", "portable"])
287287
_common_op_test("op_prod_test", ["aten", "portable"])
288+
_common_op_test("op_randn_test", ["aten", "portable"])
288289
_common_op_test("op_reciprocal_test", ["aten", "portable"])
289290
_common_op_test("op_relu_test", ["aten", "portable"])
290291
_common_op_test("op_remainder_test", ["aten", "portable"])

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,14 @@ ATEN_OPS = (
973973
"//executorch/kernels/portable/cpu/util:reduce_util",
974974
],
975975
),
976+
op_target(
977+
name = "op_randn",
978+
deps = [
979+
":scalar_utils",
980+
"//executorch/runtime/core/exec_aten/util:scalar_type_util",
981+
"//executorch/runtime/core/exec_aten/util:tensor_util",
982+
]
983+
),
976984
op_target(
977985
name = "op_reciprocal",
978986
deps = [

0 commit comments

Comments
 (0)