Commit a512f05

georgehong authored and facebook-github-bot committed
Refactor op_cat to use util pattern
Summary: Refactor cat op and helper functions to be accessible in a cat-specific target, exposing functionality to code that requires it.

Differential Revision: D79599708
1 parent b054e8d commit a512f05

File tree

5 files changed: +166 −73 lines

kernels/portable/cpu/op_cat.cpp

Lines changed: 2 additions & 73 deletions
@@ -8,6 +8,7 @@
 
 #include <cstring>
 
+#include <executorch/kernels/portable/cpu/util/cat_util.h>
 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
@@ -37,79 +38,7 @@ Tensor& cat_out(
       resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
       InvalidArgument,
       out);
-
-  // Special handling when all inputs are 1D-empty tensors for aten consistency
-  // In that case, just return an 1D-empty tensor without checking dim
-  bool all_1d_empty = true;
-  for (size_t i = 0; i < tensors.size(); ++i) {
-    if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
-      all_1d_empty = false;
-      break;
-    }
-  }
-  if (all_1d_empty) {
-    return out;
-  }
-
-  const size_t outer = getLeadingDims(out, dim);
-  const size_t dim_stride = getTrailingDims(out, dim);
-  const size_t ninputs = tensors.size();
-
-  const auto out_type = out.scalar_type();
-  const bool out_is_complex =
-      executorch::runtime::isComplexType(out.scalar_type());
-
-  if (out_is_complex) {
-    // TODO: The current support for complex dtype enforces that input and
-    // output tensors have the same dtype. Support mixed dtypes in the future.
-    for (size_t i = 0; i < ninputs; ++i) {
-      const auto in_type = tensors[i].scalar_type();
-      ET_KERNEL_CHECK(ctx, out_type == in_type, InvalidArgument, out);
-    }
-    ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "cat.out", CTYPE, [&] {
-      CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
-      for (size_t i = 0; i < outer; ++i) {
-        for (size_t j = 0; j < ninputs; ++j) {
-          if (tensors[j].numel() == 0) {
-            return;
-          }
-          size_t inner = tensors[j].size(dim) * dim_stride;
-          const CTYPE* const in_ptr =
-              tensors[j].const_data_ptr<CTYPE>() + i * inner;
-          memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE));
-          out_ptr += inner;
-        }
-      }
-    });
-  } else {
-    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
-      CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
-      for (size_t i = 0; i < outer; ++i) {
-        for (size_t j = 0; j < ninputs; ++j) {
-          const auto in_type = tensors[j].scalar_type();
-          ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
-            if (tensors[j].numel() == 0) {
-              return;
-            }
-            size_t inner = tensors[j].size(dim) * dim_stride;
-            const CTYPE_IN* const in_ptr =
-                tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
-
-            if (sizeof(CTYPE_IN) == sizeof(CTYPE_OUT)) {
-              memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE_IN));
-            } else {
-              for (size_t k = 0; k < inner; ++k) {
-                out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
-              }
-            }
-            out_ptr += inner;
-          });
-        }
-      }
-    });
-  }
-
-  return out;
+  return cat_out_impl(ctx, tensors, dim, out);
 }
 
 } // namespace native
kernels/portable/cpu/util/cat_util.cpp

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/util/cat_util.h>
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+
+namespace torch::executor::native {
+
+bool check_cat_args(
+    executorch::aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    Tensor& out) {
+  return torch::executor::check_cat_args(
+      tensors, dim, out);
+}
+
+void get_cat_out_target_size(
+    executorch::aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    executorch::aten::SizesType* out_sizes,
+    size_t* out_ndim) {
+  torch::executor::get_cat_out_target_size(
+      tensors, dim, out_sizes, out_ndim);
+}
+
+Tensor& cat_out_impl(
+    KernelRuntimeContext& ctx,
+    executorch::aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    Tensor& out) {
+  // Special handling when all inputs are 1D-empty tensors for aten consistency
+  // In that case, just return an 1D-empty tensor without checking dim
+  bool all_1d_empty = true;
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
+      all_1d_empty = false;
+      break;
+    }
+  }
+  if (all_1d_empty) {
+    return out;
+  }
+
+  const size_t outer = getLeadingDims(out, dim);
+  const size_t dim_stride = getTrailingDims(out, dim);
+  const size_t ninputs = tensors.size();
+
+  const auto out_type = out.scalar_type();
+  const bool out_is_complex =
+      executorch::runtime::isComplexType(out.scalar_type());
+
+  if (out_is_complex) {
+    // TODO: The current support for complex dtype enforces that input and
+    // output tensors have the same dtype. Support mixed dtypes in the future.
+    for (size_t i = 0; i < ninputs; ++i) {
+      const auto in_type = tensors[i].scalar_type();
+      ET_KERNEL_CHECK(ctx, out_type == in_type, InvalidArgument, out);
+    }
+    ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "cat.out", CTYPE, [&] {
+      CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
+      for (size_t i = 0; i < outer; ++i) {
+        for (size_t j = 0; j < ninputs; ++j) {
+          if (tensors[j].numel() == 0) {
+            return;
+          }
+          size_t inner = tensors[j].size(dim) * dim_stride;
+          const CTYPE* const in_ptr =
+              tensors[j].const_data_ptr<CTYPE>() + i * inner;
+          memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE));
+          out_ptr += inner;
+        }
+      }
+    });
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
+      CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
+      for (size_t i = 0; i < outer; ++i) {
+        for (size_t j = 0; j < ninputs; ++j) {
+          const auto in_type = tensors[j].scalar_type();
+          ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
+            if (tensors[j].numel() == 0) {
+              return;
+            }
+            size_t inner = tensors[j].size(dim) * dim_stride;
+            const CTYPE_IN* const in_ptr =
+                tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
+
+            if (sizeof(CTYPE_IN) == sizeof(CTYPE_OUT)) {
+              memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE_IN));
+            } else {
+              for (size_t k = 0; k < inner; ++k) {
+                out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
+              }
+            }
+            out_ptr += inner;
+          });
+        }
+      }
+    });
+  }
+  return out;
+}
+} // namespace torch::executor::native
kernels/portable/cpu/util/cat_util.h

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch::executor::native {
+
+bool check_cat_args(
+    executorch::aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    Tensor& out);
+
+void get_cat_out_target_size(
+    executorch::aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    executorch::aten::SizesType* out_sizes,
+    size_t* out_ndim);
+
+Tensor& cat_out_impl(
+    KernelRuntimeContext& ctx,
+    executorch::aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    Tensor& out);
+
+inline Tensor& cat_out_impl(
+    executorch::aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    Tensor& out) {
+  KernelRuntimeContext ctx;
+  return cat_out_impl(ctx, tensors, dim, out);
+}
+
+} // namespace torch::executor::native
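For orientation, here is a minimal sketch (not part of this commit) of how a client kernel could consume the newly exposed API after adding a dependency on //executorch/kernels/portable/cpu/util:cat_util. The kernel name my_cat_like_op is hypothetical; cat_out_impl, check_cat_args, and get_cat_out_target_size are the helpers declared in cat_util.h above.

#include <executorch/kernels/portable/cpu/util/cat_util.h>

namespace torch::executor::native {

// Hypothetical client kernel (name is illustrative only). It assumes `out`
// has already been validated and resized, e.g. with the check_cat_args /
// get_cat_out_target_size helpers from cat_util.h, and delegates the actual
// copy to the shared implementation, the same way cat_out does after this
// refactor.
Tensor& my_cat_like_op(
    KernelRuntimeContext& ctx,
    executorch::aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out) {
  return cat_out_impl(ctx, tensors, dim, out);
}

} // namespace torch::executor::native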

kernels/portable/cpu/util/targets.bzl

Lines changed: 15 additions & 0 deletions
@@ -19,6 +19,7 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu/util:kernel_ops_util",
             "//executorch/kernels/portable/cpu:vec_ops",
             "//executorch/kernels/portable/cpu/util:matmul_ops_util",
+            "//executorch/kernels/portable/cpu/util:cat_util",
             "//executorch/kernels/portable/cpu/util:copy_ops_util",
             "//executorch/kernels/portable/cpu/util:transpose_util",
             "//executorch/kernels/portable/cpu/util:index_util",
@@ -302,6 +303,20 @@ def define_common_targets():
         visibility = ["//executorch/kernels/portable/cpu/..."],
     )
 
+    runtime.cxx_library(
+        name = "cat_util",
+        srcs = ["cat_util.cpp"],
+        exported_headers = ["cat_util.h"],
+        deps = [
+            "//executorch/runtime/kernel:kernel_includes",
+            "//executorch/kernels/portable/cpu/util:copy_ops_util",
+        ],
+        visibility = [
+            "//executorch/kernels/portable/cpu/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+    )
+
     runtime.cxx_library(
         name = "broadcast_indexes_range",
         exported_headers = ["broadcast_indexes_range.h"],

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 1 addition & 0 deletions
@@ -378,6 +378,7 @@ ATEN_OPS = (
     op_target(
         name = "op_cat",
         deps = [
+            "//executorch/kernels/portable/cpu/util:cat_util",
             "//executorch/kernels/portable/cpu/util:copy_ops_util",
         ],
     ),
