Skip to content

Commit 726292e

Browse files
pssrawatfacebook-github-bot
authored andcommitted
Add view_as_real_copy.out (pytorch#10207)
Summary: As title. Needed for multichannel ASR. Differential Revision: D72294238
1 parent 900b42c commit 726292e

File tree

9 files changed

+161
-0
lines changed

9 files changed

+161
-0
lines changed

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,8 @@
423423

424424
- op: var.out
425425

426+
- op: view_as_real_copy.out
427+
426428
- op: view_copy.out
427429

428430
- op: where.self_out
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
10+
#include <executorch/kernels/portable/cpu/util/functional_util.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <executorch/runtime/platform/assert.h>
13+
14+
namespace torch {
15+
namespace executor {
16+
namespace native {
17+
18+
using Tensor = executorch::aten::Tensor;
19+
20+
template <typename SELF_CTYPE, typename OUT_CTYPE>
21+
inline void _to_impl(const Tensor& self, Tensor& out) {
22+
auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
23+
auto out_data = out.mutable_data_ptr<OUT_CTYPE>();
24+
25+
for (size_t i = 0, e = self.numel(); i < e; i++) {
26+
auto val_in = self_data[i];
27+
out_data[2 * i] = static_cast<OUT_CTYPE>(val_in.real_);
28+
out_data[2 * i + 1] = static_cast<OUT_CTYPE>(val_in.imag_);
29+
}
30+
}
31+
32+
// view_as_real_copy(Tensor self) -> Tensor
33+
Tensor& view_as_real_copy_out(
34+
KernelRuntimeContext& ctx,
35+
const Tensor& self,
36+
Tensor& out) {
37+
(void)ctx;
38+
39+
// Get the output shape
40+
Tensor::SizesType expected_output_size[kTensorDimensionLimit];
41+
get_view_as_real_copy_out_target_size(self, expected_output_size);
42+
43+
// Resize for dynamic shape
44+
ET_KERNEL_CHECK_MSG(
45+
ctx,
46+
resize_tensor(
47+
out, {expected_output_size, static_cast<size_t>(out.dim())}) ==
48+
Error::Ok,
49+
InvalidArgument,
50+
out,
51+
"Failed to resize output tensor.");
52+
53+
// The input tensor must be complex type
54+
ET_KERNEL_CHECK_MSG(
55+
ctx,
56+
executorch::runtime::isComplexType(self.scalar_type()),
57+
InvalidArgument,
58+
out,
59+
"Input tensor must be complex type");
60+
61+
ET_KERNEL_CHECK(
62+
ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);
63+
64+
constexpr auto op_name = "view_as_real_copy.out";
65+
66+
ET_SWITCH_COMPLEXH_TYPES(self.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
67+
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
68+
_to_impl<CTYPE_IN, CTYPE_OUT>(self, out);
69+
});
70+
});
71+
72+
return out;
73+
}
74+
75+
} // namespace native
76+
} // namespace executor
77+
} // namespace torch

kernels/portable/cpu/util/copy_ops_util.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,5 +1018,14 @@ void get_unfold_copy_out_target_size(
10181018
*out_ndim = self.dim() + 1;
10191019
}
10201020

1021+
void get_view_as_real_copy_out_target_size(
1022+
const Tensor& self,
1023+
executorch::aten::SizesType* out_sizes) {
1024+
for (auto i : c10::irange(self.dim())) {
1025+
out_sizes[i] = self.size(i);
1026+
}
1027+
out_sizes[self.dim()] = 2;
1028+
}
1029+
10211030
} // namespace executor
10221031
} // namespace torch

kernels/portable/cpu/util/copy_ops_util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,5 +247,9 @@ void get_unfold_copy_out_target_size(
247247
executorch::aten::SizesType* out_sizes,
248248
size_t* out_ndim);
249249

250+
void get_view_as_real_copy_out_target_size(
251+
const Tensor& self,
252+
executorch::aten::SizesType* out_sizes);
253+
250254
} // namespace executor
251255
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,11 @@
957957
- arg_meta: null
958958
kernel_name: torch::executor::var_out
959959

960+
- op: view_as_real_copy.out
961+
kernels:
962+
- arg_meta: null
963+
kernel_name: torch::executor::view_as_real_copy_out
964+
960965
- op: view_copy.out
961966
kernels:
962967
- arg_meta: null

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ set(all_test_sources
242242
"op_upsample_bilinear2d_test.cpp"
243243
"op_upsample_nearest2d_test.cpp"
244244
"op_var_test.cpp"
245+
"op_view_as_real_copy_test.cpp"
245246
"op_view_copy_test.cpp"
246247
"op_where_test.cpp"
247248
"op_zeros_test.cpp"
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/TestUtil.h>
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14+
15+
#include <gtest/gtest.h>
16+
17+
using namespace ::testing;
18+
using executorch::aten::ScalarType;
19+
using executorch::aten::Tensor;
20+
using torch::executor::testing::TensorFactory;
21+
22+
class OpViewAsRealTest : public OperatorTest {
23+
protected:
24+
Tensor& view_as_real_copy_out(const Tensor& self, Tensor& out) {
25+
return torch::executor::aten::view_as_real_copy_outf(context_, self, out);
26+
}
27+
28+
template <typename CTYPE, ScalarType DTYPE>
29+
void run_complex_smoke_test() {
30+
TensorFactory<DTYPE> tf;
31+
constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);
32+
TensorFactory<REAL_DTYPE> tf_out;
33+
using REAL_CTYPE =
34+
typename executorch::runtime::ScalarTypeToCppType<REAL_DTYPE>::type;
35+
Tensor in = tf.make(
36+
{2, 2},
37+
{CTYPE{REAL_CTYPE(3), REAL_CTYPE(4)},
38+
CTYPE{REAL_CTYPE(-1.7), REAL_CTYPE(7.4)},
39+
CTYPE{REAL_CTYPE(5), REAL_CTYPE(-12)},
40+
CTYPE{REAL_CTYPE(8.3), REAL_CTYPE(0.1)}});
41+
Tensor out = tf_out.zeros({2, 2, 2});
42+
Tensor expected =
43+
tf_out.make({2, 2, 2}, {3, 4, -1.7, 7.4, 5, -12, 8.3, 0.1});
44+
Tensor ret = view_as_real_copy_out(in, out);
45+
EXPECT_TENSOR_EQ(out, ret);
46+
EXPECT_TENSOR_EQ(out, expected);
47+
}
48+
};
49+
50+
TEST_F(OpViewAsRealTest, ComplexSmokeTest) {
51+
#define RUN_SMOKE_TEST(ctype, dtype) \
52+
run_complex_smoke_test<ctype, ScalarType::dtype>();
53+
ET_FORALL_COMPLEXH_TYPES(RUN_SMOKE_TEST);
54+
#undef RUN_SMOKE_TEST
55+
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ def define_common_targets():
331331
_common_op_test("op_upsample_bilinear2d_test", ["aten", "portable"])
332332
_common_op_test("op_upsample_nearest2d_test", ["aten", "portable"])
333333
_common_op_test("op_var_test", ["aten", "portable"])
334+
_common_op_test("op_view_as_real_copy_test", ["aten", "portable"])
334335
_common_op_test("op_view_copy_test", ["aten", "portable"])
335336
_common_op_test("op_where_test", ["aten", "portable"])
336337
_common_op_test("op_zeros_test", ["aten", "portable"])

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,13 @@ ATEN_OPS = (
12691269
"//executorch/kernels/portable/cpu/util:reduce_util",
12701270
],
12711271
),
1272+
op_target(
1273+
name = "op_view_as_real_copy",
1274+
deps = [
1275+
"//executorch/kernels/portable/cpu/util:functional_util",
1276+
"//executorch/kernels/portable/cpu/util:copy_ops_util",
1277+
],
1278+
),
12721279
op_target(
12731280
name = "op_view_copy",
12741281
deps = [

0 commit comments

Comments
 (0)