Skip to content

Commit 6a72709

Browse files
pytorchbot and larryliu0820
authored and committed
[llm] Add arange() tensor maker API (pytorch#11965)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: pytorch#11861 by @larryliu0820 ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/larryliu0820/66/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/larryliu0820/66/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/larryliu0820/66/orig @diff-train-skip-merge Co-authored-by: Mengwei Liu <[email protected]>
1 parent d77f9f3 commit 6a72709

File tree

5 files changed

+113
-20
lines changed

5 files changed

+113
-20
lines changed

kernels/portable/cpu/op_arange.cpp

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10+
#include <executorch/kernels/portable/cpu/util/arange_util.h>
1011
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
1112
#include <executorch/runtime/kernel/kernel_includes.h>
1213
#include <executorch/runtime/platform/assert.h>
@@ -29,22 +30,15 @@ Tensor& arange_out(KernelRuntimeContext& ctx, const Scalar& end, Tensor& out) {
2930

3031
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out);
3132

32-
size_t size = static_cast<size_t>(std::ceil(end_val));
33-
34-
Tensor::SizesType out_length = static_cast<Tensor::SizesType>(size);
33+
Tensor::SizesType out_length = compute_arange_out_size(0.0, end_val, 1.0);
3534

3635
ET_KERNEL_CHECK(
3736
ctx,
3837
resize_tensor(out, {&out_length, 1}) == Error::Ok,
3938
InvalidArgument,
4039
out);
4140

42-
ET_SWITCH_REALHBF16_TYPES(out.scalar_type(), ctx, "arange.out", CTYPE, [&]() {
43-
auto out_data = out.mutable_data_ptr<CTYPE>();
44-
for (size_t i = 0; i < size; i++) {
45-
out_data[i] = static_cast<CTYPE>(i);
46-
}
47-
});
41+
arange_out_impl(ctx, end_val, out);
4842

4943
return out;
5044
}
@@ -77,24 +71,16 @@ Tensor& arange_start_out(
7771

7872
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out);
7973

80-
double size_d = (d_end - d_start) / d_step;
81-
size_t size = static_cast<size_t>(std::ceil(size_d));
82-
83-
Tensor::SizesType out_length = static_cast<Tensor::SizesType>(size);
74+
Tensor::SizesType out_length =
75+
compute_arange_out_size(d_start, d_end, d_step);
8476

8577
ET_KERNEL_CHECK(
8678
ctx,
8779
resize_tensor(out, {&out_length, 1}) == Error::Ok,
8880
InvalidArgument,
8981
out);
9082

91-
ET_SWITCH_REALHBF16_TYPES(
92-
out.scalar_type(), ctx, "arange.start_out", CTYPE, [&]() {
93-
auto out_data = out.mutable_data_ptr<CTYPE>();
94-
for (size_t i = 0; i < size; i++) {
95-
out_data[i] = convert<CTYPE, double>(d_start + i * d_step);
96-
}
97-
});
83+
arange_out_impl(ctx, d_start, d_end, d_step, out);
9884

9985
return out;
10086
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/arange_util.h>

#include <cmath> // for std::ceil — previously pulled in only transitively

namespace torch::executor::native {

// Fills `out` with the arithmetic sequence start, start + step, ... for
// `numel` elements, dispatching on the output's real/half/bfloat16 dtype.
// NOTE: `numel` is used as the loop bound via `i < numel`; in the
// one-argument overload below the raw double `end` is passed here, which
// runs the loop ceil(end) times via int-to-double promotion.
#define ET_ARANGE_IMPL(ctx, start, numel, step, out, op_name)               \
  ET_SWITCH_REALHBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE, [&]() { \
    auto out_data = out.mutable_data_ptr<CTYPE>();                          \
    for (Tensor::SizesType i = 0; i < numel; ++i) {                         \
      out_data[i] = static_cast<CTYPE>(start + i * step);                   \
    }                                                                       \
  })

// Returns the number of elements produced by arange(start, end, step):
// ceil((end - start) / step).
// Aborts when step is zero (the division would yield inf/NaN, and casting
// that to an integral SizesType is undefined behavior) or when the
// computed length is negative.
Tensor::SizesType
compute_arange_out_size(double start, double end, double step) {
  ET_CHECK_MSG(step != 0.0, "step should be nonzero, but got (%f)", step);
  Tensor::SizesType numel =
      static_cast<Tensor::SizesType>(std::ceil((end - start) / step));

  ET_CHECK_MSG(
      numel >= 0,
      "numel should be non-negative, but got (%d). start (%f), end (%f), step (%f)",
      numel,
      start,
      end,
      step);
  return numel;
}

// Writes the sequence start, start + step, ... into `out`. The caller is
// expected to have resized `out` to compute_arange_out_size(...) elements.
void arange_out_impl(
    KernelRuntimeContext& ctx,
    double start,
    double end,
    double step,
    Tensor& out) {
  (void)ctx;
  Tensor::SizesType numel = compute_arange_out_size(start, end, step);
  ET_ARANGE_IMPL(ctx, start, numel, step, out, "arange.start_out");
}

// Writes 0, 1, 2, ... into `out`, stopping before `end` (step fixed at 1).
void arange_out_impl(KernelRuntimeContext& ctx, double end, Tensor& out) {
  (void)ctx;
  ET_ARANGE_IMPL(ctx, 0.0, end, 1.0, out, "arange.out");
}

} // namespace torch::executor::native
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch::executor::native {

/// Returns the number of elements that arange(start, end, step) produces.
Tensor::SizesType
compute_arange_out_size(double start, double end, double step);

/// Convenience overload for arange(end): start = 0, step = 1.
inline Tensor::SizesType compute_arange_out_size(double end) {
  return compute_arange_out_size(0.0, end, 1.0);
}

/// Fills `out` with the sequence start, start + step, ...
void arange_out_impl(
    KernelRuntimeContext& ctx,
    double start,
    double end,
    double step,
    Tensor& out);

/// Fills `out` with 0, 1, 2, ..., stopping before `end`.
void arange_out_impl(KernelRuntimeContext& ctx, double end, Tensor& out);

/// Context-free convenience wrapper: supplies a default runtime context.
inline void
arange_out_impl(double start, double end, double step, Tensor& out) {
  KernelRuntimeContext context;
  arange_out_impl(context, start, end, step, out);
}

/// Context-free convenience wrapper for arange(end).
inline void arange_out_impl(double end, Tensor& out) {
  KernelRuntimeContext context;
  arange_out_impl(context, 0.0, end, 1.0, out);
}
} // namespace torch::executor::native

kernels/portable/cpu/util/targets.bzl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def define_common_targets():
1313
name = "all_deps",
1414
exported_deps = [
1515
"//executorch/extension/threadpool:threadpool",
16+
"//executorch/kernels/portable/cpu/util:arange_util",
1617
"//executorch/kernels/portable/cpu/util:functional_util",
1718
"//executorch/kernels/portable/cpu/util:broadcast_util",
1819
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
@@ -300,6 +301,19 @@ def define_common_targets():
300301
visibility = ["//executorch/kernels/portable/cpu/..."],
301302
)
302303

304+
runtime.cxx_library(
305+
name = "arange_util",
306+
srcs = ["arange_util.cpp"],
307+
exported_headers = ["arange_util.h"],
308+
deps = [
309+
"//executorch/runtime/kernel:kernel_includes",
310+
],
311+
visibility = [
312+
"//executorch/kernels/portable/cpu/...",
313+
"//executorch/extension/llm/...",
314+
],
315+
)
316+
303317
runtime.cxx_library(
304318
name = "broadcast_indexes_range",
305319
exported_headers = ["broadcast_indexes_range.h"],

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ ATEN_OPS = (
268268
op_target(
269269
name = "op_arange",
270270
deps = [
271+
"//executorch/kernels/portable/cpu/util:arange_util",
271272
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
272273
"//executorch/runtime/core/exec_aten/util:scalar_type_util",
273274
"//executorch/runtime/core/exec_aten/util:tensor_util",

0 commit comments

Comments (0)