Skip to content

Commit 3a76a8b

Browse files
authored
Add bitwise left/right (pytorch#15893)
Fixes pytorch#15673 (comment) Did not add as core aten op. But used "func" and enabled exporting via: ``` to_edge_transform_and_lower( exported_program, partitioner=None, compile_config=EdgeCompileConfig( _core_aten_ops_exception_list=[ torch.ops.aten.bitwise_left_shift.Tensor, torch.ops.aten.bitwise_left_shift.Tensor_Scalar, torch.ops.aten.bitwise_right_shift.Tensor, torch.ops.aten.bitwise_right_shift.Tensor_Scalar, ] ), ) ``` In the future, we may promote to core aten op.
1 parent e793135 commit 3a76a8b

File tree

11 files changed

+600
-0
lines changed

11 files changed

+600
-0
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
10+
11+
namespace torch {
12+
namespace executor {
13+
namespace native {
14+
15+
Tensor& bitwise_left_shift_Tensor_out(
16+
KernelRuntimeContext& ctx,
17+
const Tensor& a,
18+
const Tensor& b,
19+
Tensor& out) {
20+
// @lint-ignore CLANGTIDY facebook-hte-CArray
21+
static constexpr const char op_name[] = "bitwise_left_shift.Tensor_out";
22+
return internal::bitwise_tensor_out<internal::bit_lshift, op_name>(
23+
ctx, a, b, out);
24+
}
25+
26+
Tensor& bitwise_left_shift_Tensor_Scalar_out(
27+
KernelRuntimeContext& ctx,
28+
const Tensor& a,
29+
const Scalar& b,
30+
Tensor& out) {
31+
// @lint-ignore CLANGTIDY facebook-hte-CArray
32+
static constexpr const char op_name[] =
33+
"bitwise_left_shift.Tensor_Scalar_out";
34+
return internal::bitwise_scalar_out<internal::bit_lshift, op_name>(
35+
ctx, a, b, out);
36+
}
37+
38+
} // namespace native
39+
} // namespace executor
40+
} // namespace torch
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
10+
11+
namespace torch {
12+
namespace executor {
13+
namespace native {
14+
15+
Tensor& bitwise_right_shift_Tensor_out(
16+
KernelRuntimeContext& ctx,
17+
const Tensor& a,
18+
const Tensor& b,
19+
Tensor& out) {
20+
// @lint-ignore CLANGTIDY facebook-hte-CArray
21+
static constexpr const char op_name[] = "bitwise_right_shift.Tensor_out";
22+
return internal::bitwise_tensor_out<internal::bit_rshift, op_name>(
23+
ctx, a, b, out);
24+
}
25+
26+
Tensor& bitwise_right_shift_Tensor_Scalar_out(
27+
KernelRuntimeContext& ctx,
28+
const Tensor& a,
29+
const Scalar& b,
30+
Tensor& out) {
31+
// @lint-ignore CLANGTIDY facebook-hte-CArray
32+
static constexpr const char op_name[] =
33+
"bitwise_right_shift.Tensor_Scalar_out";
34+
return internal::bitwise_scalar_out<internal::bit_rshift, op_name>(
35+
ctx, a, b, out);
36+
}
37+
38+
} // namespace native
39+
} // namespace executor
40+
} // namespace torch

kernels/portable/cpu/pattern/bitwise_op.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,21 @@ DEFINE_BINARY_OPERATOR_TEMPLATE(bitwise_and, &)
2727
DEFINE_BINARY_OPERATOR_TEMPLATE(bitwise_or, |)
2828
DEFINE_BINARY_OPERATOR_TEMPLATE(bitwise_xor, ^)
2929

30+
// Functor wrappers for shift operations (similar to std::bit_and, etc.)
31+
template <typename T = void>
32+
struct bit_lshift {
33+
constexpr T operator()(const T& lhs, const T& rhs) const {
34+
return static_cast<T>(lhs << rhs);
35+
}
36+
};
37+
38+
template <typename T = void>
39+
struct bit_rshift {
40+
constexpr T operator()(const T& lhs, const T& rhs) const {
41+
return static_cast<T>(lhs >> rhs);
42+
}
43+
};
44+
3045
template <typename T>
3146
using bitwise_fn = T (*)(const T, const T);
3247

kernels/portable/functions.yaml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,26 @@
207207
- arg_meta: null
208208
kernel_name: torch::executor::bitwise_xor_Tensor_out
209209

210+
- func: bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
211+
kernels:
212+
- arg_meta: null
213+
kernel_name: torch::executor::bitwise_left_shift_Tensor_out
214+
215+
- func: bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
216+
kernels:
217+
- arg_meta: null
218+
kernel_name: torch::executor::bitwise_left_shift_Tensor_Scalar_out
219+
220+
- func: bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
221+
kernels:
222+
- arg_meta: null
223+
kernel_name: torch::executor::bitwise_right_shift_Tensor_out
224+
225+
- func: bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
226+
kernels:
227+
- arg_meta: null
228+
kernel_name: torch::executor::bitwise_right_shift_Tensor_Scalar_out
229+
210230
- op: bmm.out
211231
kernels:
212232
- arg_meta: null

kernels/test/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ set(all_test_sources
147147
"op_atanh_test.cpp"
148148
"op_avg_pool2d_test.cpp"
149149
"op_bitwise_and_test.cpp"
150+
"op_bitwise_left_shift_test.cpp"
150151
"op_bitwise_not_test.cpp"
151152
"op_bitwise_or_test.cpp"
153+
"op_bitwise_right_shift_test.cpp"
152154
"op_bitwise_xor_test.cpp"
153155
"op_bmm_test.cpp"
154156
"op_cat_test.cpp"

kernels/test/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,23 @@ oncall("executorch")
99

1010
define_common_targets()
1111

12+
python_unittest(
13+
name = "test_bitwise_shift",
14+
srcs = ["test_bitwise_shift.py"],
15+
preload_deps = [
16+
"//executorch/kernels/portable:custom_ops_generated_lib",
17+
],
18+
deps = [
19+
"//caffe2:torch",
20+
"//executorch/exir:lib",
21+
"//executorch/extension/export_util:export_util",
22+
"//executorch/runtime:runtime",
23+
],
24+
env = {
25+
"PYTORCH_DISABLE_JUSTKNOBS": "1",
26+
},
27+
)
28+
1229
python_unittest(
1330
name = "gen_supported_features_test",
1431
srcs = ["gen_supported_features_test.py"],
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/TestUtil.h>
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14+
15+
#include <gtest/gtest.h>
16+
17+
using namespace ::testing;
18+
using executorch::aten::Scalar;
19+
using executorch::aten::ScalarType;
20+
using executorch::aten::Tensor;
21+
using torch::executor::testing::TensorFactory;
22+
23+
class OpBitwiseLeftShiftTensorOutTest : public OperatorTest {
24+
protected:
25+
Tensor& op_bitwise_left_shift_tensor_out(
26+
const Tensor& self,
27+
const Tensor& other,
28+
Tensor& out) {
29+
return torch::executor::aten::bitwise_left_shift_outf(
30+
context_, self, other, out);
31+
}
32+
};
33+
34+
class OpBitwiseLeftShiftScalarOutTest : public OperatorTest {
35+
protected:
36+
Tensor& op_bitwise_left_shift_scalar_out(
37+
const Tensor& self,
38+
const Scalar& other,
39+
Tensor& out) {
40+
return torch::executor::aten::bitwise_left_shift_outf(
41+
context_, self, other, out);
42+
}
43+
};
44+
45+
TEST_F(OpBitwiseLeftShiftTensorOutTest, SmokeTestInt) {
46+
TensorFactory<ScalarType::Int> tf;
47+
48+
// Test basic left shift: [1, 2, 4, 8] << [1, 2, 1, 2] = [2, 8, 8, 32]
49+
Tensor a = tf.make({2, 2}, {1, 2, 4, 8});
50+
Tensor b = tf.make({2, 2}, {1, 2, 1, 2});
51+
52+
Tensor out = tf.zeros({2, 2});
53+
54+
op_bitwise_left_shift_tensor_out(a, b, out);
55+
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {2, 8, 8, 32}));
56+
}
57+
58+
TEST_F(OpBitwiseLeftShiftTensorOutTest, SmokeTestByte) {
59+
TensorFactory<ScalarType::Byte> tf;
60+
61+
// Test with byte values: [1, 5, 10, 15] << [0, 1, 2, 3] = [1, 10, 40, 120]
62+
Tensor a = tf.make({2, 2}, {1, 5, 10, 15});
63+
Tensor b = tf.make({2, 2}, {0, 1, 2, 3});
64+
65+
Tensor out = tf.zeros({2, 2});
66+
67+
op_bitwise_left_shift_tensor_out(a, b, out);
68+
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {1, 10, 40, 120}));
69+
}
70+
71+
TEST_F(OpBitwiseLeftShiftTensorOutTest, ZeroShift) {
72+
TensorFactory<ScalarType::Int> tf;
73+
74+
// Shifting by 0 should return the original value
75+
Tensor a = tf.make({2, 2}, {5, 10, 15, 20});
76+
Tensor b = tf.zeros({2, 2});
77+
78+
Tensor out = tf.zeros({2, 2});
79+
80+
op_bitwise_left_shift_tensor_out(a, b, out);
81+
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {5, 10, 15, 20}));
82+
}
83+
84+
TEST_F(OpBitwiseLeftShiftScalarOutTest, SmokeTestInt) {
85+
TensorFactory<ScalarType::Int> tf;
86+
87+
// Test shifting by scalar: [1, 2, 3, 4] << 2 = [4, 8, 12, 16]
88+
Tensor a = tf.make({2, 2}, {1, 2, 3, 4});
89+
Scalar b = 2;
90+
91+
Tensor out = tf.zeros({2, 2});
92+
93+
op_bitwise_left_shift_scalar_out(a, b, out);
94+
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {4, 8, 12, 16}));
95+
}
96+
97+
TEST_F(OpBitwiseLeftShiftScalarOutTest, ShiftByOne) {
98+
TensorFactory<ScalarType::Int> tf;
99+
100+
// Shifting by 1 should double the value
101+
Tensor a = tf.make({2, 2}, {1, 5, 10, 100});
102+
Scalar b = 1;
103+
104+
Tensor out = tf.zeros({2, 2});
105+
106+
op_bitwise_left_shift_scalar_out(a, b, out);
107+
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {2, 10, 20, 200}));
108+
}
109+
110+
TEST_F(OpBitwiseLeftShiftScalarOutTest, ShiftByZero) {
111+
TensorFactory<ScalarType::Int> tf;
112+
113+
// Shifting by 0 should return the original value
114+
Tensor a = tf.make({2, 2}, {7, 14, 21, 28});
115+
Scalar b = 0;
116+
117+
Tensor out = tf.zeros({2, 2});
118+
119+
op_bitwise_left_shift_scalar_out(a, b, out);
120+
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {7, 14, 21, 28}));
121+
}

0 commit comments

Comments
 (0)