Skip to content

Commit 2737333

Browse files
committed
Update
[ghstack-poisoned]
1 parent 0d910f0 commit 2737333

File tree

2 files changed

+44
-24
lines changed

2 files changed

+44
-24
lines changed

kernels/portable/cpu/op_prod.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ Tensor& prod_out(
3333
ScalarType out_type = out.scalar_type();
3434
constexpr auto name = "prod.int_out";
3535

36-
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
37-
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
36+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
37+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
3838
const auto data_in = in.const_data_ptr<CTYPE_IN>();
3939
auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
4040
data_out[0] = static_cast<CTYPE_OUT>(1);
@@ -73,8 +73,8 @@ Tensor& prod_int_out(
7373
ScalarType out_type = out.scalar_type();
7474
constexpr auto name = "prod.int_out";
7575

76-
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
77-
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
76+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
77+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
7878
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
7979
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
8080
CTYPE_OUT prod = 1;

kernels/test/op_prod_test.cpp

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,24 @@ class OpProdOutTest : public ::testing::Test {
4545
// first.
4646
torch::executor::runtime_init();
4747
}
48+
49+
template <ScalarType DTYPE>
50+
void test_dtype() {
51+
TensorFactory<DTYPE> tf;
52+
TensorFactory<
53+
executorch::runtime::isIntegralType(DTYPE, /*includeBool*/ true)
54+
? ScalarType::Long
55+
: DTYPE>
56+
tf_out;
57+
58+
Tensor self = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
59+
optional<ScalarType> dtype{};
60+
Tensor out = tf_out.zeros({});
61+
Tensor out_expected =
62+
tf_out.make({}, {DTYPE == ScalarType::Bool ? 1 : 720});
63+
op_prod_out(self, dtype, out);
64+
EXPECT_TENSOR_CLOSE(out, out_expected);
65+
}
4866
};
4967

5068
class OpProdIntOutTest : public ::testing::Test {
@@ -54,30 +72,32 @@ class OpProdIntOutTest : public ::testing::Test {
5472
// first.
5573
torch::executor::runtime_init();
5674
}
57-
};
5875

59-
TEST_F(OpProdOutTest, SmokeTest) {
60-
TensorFactory<ScalarType::Float> tfFloat;
76+
template <ScalarType DTYPE>
77+
void test_dtype() {
78+
TensorFactory<DTYPE> tf;
6179

62-
Tensor self = tfFloat.make({2, 3}, {1, 2, 3, 4, 5, 6});
63-
optional<ScalarType> dtype{};
64-
Tensor out = tfFloat.zeros({});
65-
Tensor out_expected = tfFloat.make({}, {720});
66-
op_prod_out(self, dtype, out);
67-
EXPECT_TENSOR_CLOSE(out, out_expected);
68-
}
80+
Tensor self = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
81+
int64_t dim = 0;
82+
bool keepdim = false;
83+
optional<ScalarType> dtype{};
84+
Tensor out = tf.zeros({3});
85+
Tensor out_expected = tf.make({3}, {4, 10, 18});
86+
op_prod_int_out(self, dim, keepdim, dtype, out);
87+
EXPECT_TENSOR_CLOSE(out, out_expected);
88+
}
89+
};
6990

70-
TEST_F(OpProdIntOutTest, SmokeTest) {
71-
TensorFactory<ScalarType::Float> tfFloat;
91+
TEST_F(OpProdOutTest, SmokeTest){
92+
#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
93+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY)
94+
#undef TEST_ENTRY
95+
}
7296

73-
Tensor self = tfFloat.make({2, 3}, {1, 2, 3, 4, 5, 6});
74-
int64_t dim = 0;
75-
bool keepdim = false;
76-
optional<ScalarType> dtype{};
77-
Tensor out = tfFloat.zeros({3});
78-
Tensor out_expected = tfFloat.make({3}, {4, 10, 18});
79-
op_prod_int_out(self, dim, keepdim, dtype, out);
80-
EXPECT_TENSOR_CLOSE(out, out_expected);
97+
TEST_F(OpProdIntOutTest, SmokeTest){
98+
#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
99+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY)
100+
#undef TEST_ENTRY
81101
}
82102

83103
TEST_F(OpProdIntOutTest, SmokeTestKeepdim) {

0 commit comments

Comments
 (0)