Skip to content

Commit 0acc09a

Browse files
authored
Revert "Support Half/BFloat16 in prod operator (#7857)"
This reverts commit 0d365a6.
1 parent 0d365a6 commit 0acc09a

File tree

2 files changed

+24
-44
lines changed

2 files changed

+24
-44
lines changed

kernels/portable/cpu/op_prod.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ Tensor& prod_out(
3333
ScalarType out_type = out.scalar_type();
3434
constexpr auto name = "prod.int_out";
3535

36-
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
37-
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
36+
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
37+
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
3838
const auto data_in = in.const_data_ptr<CTYPE_IN>();
3939
auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
4040
data_out[0] = static_cast<CTYPE_OUT>(1);
@@ -73,8 +73,8 @@ Tensor& prod_int_out(
7373
ScalarType out_type = out.scalar_type();
7474
constexpr auto name = "prod.int_out";
7575

76-
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
77-
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
76+
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
77+
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
7878
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
7979
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
8080
CTYPE_OUT prod = 1;

kernels/test/op_prod_test.cpp

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -45,24 +45,6 @@ class OpProdOutTest : public ::testing::Test {
4545
// first.
4646
torch::executor::runtime_init();
4747
}
48-
49-
template <ScalarType DTYPE>
50-
void test_dtype() {
51-
TensorFactory<DTYPE> tf;
52-
TensorFactory<
53-
executorch::runtime::isIntegralType(DTYPE, /*includeBool*/ true)
54-
? ScalarType::Long
55-
: DTYPE>
56-
tf_out;
57-
58-
Tensor self = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
59-
optional<ScalarType> dtype{};
60-
Tensor out = tf_out.zeros({});
61-
Tensor out_expected =
62-
tf_out.make({}, {DTYPE == ScalarType::Bool ? 1 : 720});
63-
op_prod_out(self, dtype, out);
64-
EXPECT_TENSOR_CLOSE(out, out_expected);
65-
}
6648
};
6749

6850
class OpProdIntOutTest : public ::testing::Test {
@@ -72,32 +54,30 @@ class OpProdIntOutTest : public ::testing::Test {
7254
// first.
7355
torch::executor::runtime_init();
7456
}
75-
76-
template <ScalarType DTYPE>
77-
void test_dtype() {
78-
TensorFactory<DTYPE> tf;
79-
80-
Tensor self = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
81-
int64_t dim = 0;
82-
bool keepdim = false;
83-
optional<ScalarType> dtype{};
84-
Tensor out = tf.zeros({3});
85-
Tensor out_expected = tf.make({3}, {4, 10, 18});
86-
op_prod_int_out(self, dim, keepdim, dtype, out);
87-
EXPECT_TENSOR_CLOSE(out, out_expected);
88-
}
8957
};
9058

91-
TEST_F(OpProdOutTest, SmokeTest){
92-
#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
93-
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY)
94-
#undef TEST_ENTRY
59+
TEST_F(OpProdOutTest, SmokeTest) {
60+
TensorFactory<ScalarType::Float> tfFloat;
61+
62+
Tensor self = tfFloat.make({2, 3}, {1, 2, 3, 4, 5, 6});
63+
optional<ScalarType> dtype{};
64+
Tensor out = tfFloat.zeros({});
65+
Tensor out_expected = tfFloat.make({}, {720});
66+
op_prod_out(self, dtype, out);
67+
EXPECT_TENSOR_CLOSE(out, out_expected);
9568
}
9669

97-
TEST_F(OpProdIntOutTest, SmokeTest){
98-
#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
99-
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY)
100-
#undef TEST_ENTRY
70+
TEST_F(OpProdIntOutTest, SmokeTest) {
71+
TensorFactory<ScalarType::Float> tfFloat;
72+
73+
Tensor self = tfFloat.make({2, 3}, {1, 2, 3, 4, 5, 6});
74+
int64_t dim = 0;
75+
bool keepdim = false;
76+
optional<ScalarType> dtype{};
77+
Tensor out = tfFloat.zeros({3});
78+
Tensor out_expected = tfFloat.make({3}, {4, 10, 18});
79+
op_prod_int_out(self, dim, keepdim, dtype, out);
80+
EXPECT_TENSOR_CLOSE(out, out_expected);
10181
}
10282

10383
TEST_F(OpProdIntOutTest, SmokeTestKeepdim) {

0 commit comments

Comments
 (0)