@@ -45,6 +45,24 @@ class OpProdOutTest : public ::testing::Test {
4545 // first.
4646 torch::executor::runtime_init ();
4747 }
48+
49+ template <ScalarType DTYPE>
50+ void test_dtype () {
51+ TensorFactory<DTYPE> tf;
52+ TensorFactory<
53+ executorch::runtime::isIntegralType (DTYPE, /* includeBool*/ true )
54+ ? ScalarType::Long
55+ : DTYPE>
56+ tf_out;
57+
58+ Tensor self = tf.make ({2 , 3 }, {1 , 2 , 3 , 4 , 5 , 6 });
59+ optional<ScalarType> dtype{};
60+ Tensor out = tf_out.zeros ({});
61+ Tensor out_expected =
62+ tf_out.make ({}, {DTYPE == ScalarType::Bool ? 1 : 720 });
63+ op_prod_out (self, dtype, out);
64+ EXPECT_TENSOR_CLOSE (out, out_expected);
65+ }
4866};
4967
5068class OpProdIntOutTest : public ::testing::Test {
@@ -54,30 +72,32 @@ class OpProdIntOutTest : public ::testing::Test {
5472 // first.
5573 torch::executor::runtime_init ();
5674 }
57- };
5875
59- TEST_F (OpProdOutTest, SmokeTest) {
60- TensorFactory<ScalarType::Float> tfFloat;
76+ template <ScalarType DTYPE>
77+ void test_dtype () {
78+ TensorFactory<DTYPE> tf;
6179
62- Tensor self = tfFloat.make ({2 , 3 }, {1 , 2 , 3 , 4 , 5 , 6 });
63- optional<ScalarType> dtype{};
64- Tensor out = tfFloat.zeros ({});
65- Tensor out_expected = tfFloat.make ({}, {720 });
66- op_prod_out (self, dtype, out);
67- EXPECT_TENSOR_CLOSE (out, out_expected);
68- }
80+ Tensor self = tf.make ({2 , 3 }, {1 , 2 , 3 , 4 , 5 , 6 });
81+ int64_t dim = 0 ;
82+ bool keepdim = false ;
83+ optional<ScalarType> dtype{};
84+ Tensor out = tf.zeros ({3 });
85+ Tensor out_expected = tf.make ({3 }, {4 , 10 , 18 });
86+ op_prod_int_out (self, dim, keepdim, dtype, out);
87+ EXPECT_TENSOR_CLOSE (out, out_expected);
88+ }
89+ };
6990
70- TEST_F (OpProdIntOutTest, SmokeTest) {
71- TensorFactory<ScalarType::Float> tfFloat;
91+ TEST_F (OpProdOutTest, SmokeTest){
92+ #define TEST_ENTRY (ctype, dtype ) test_dtype<ScalarType::dtype>();
93+ ET_FORALL_REALHBBF16_TYPES (TEST_ENTRY)
94+ #undef TEST_ENTRY
95+ }
7296
73- Tensor self = tfFloat.make ({2 , 3 }, {1 , 2 , 3 , 4 , 5 , 6 });
74- int64_t dim = 0 ;
75- bool keepdim = false ;
76- optional<ScalarType> dtype{};
77- Tensor out = tfFloat.zeros ({3 });
78- Tensor out_expected = tfFloat.make ({3 }, {4 , 10 , 18 });
79- op_prod_int_out (self, dim, keepdim, dtype, out);
80- EXPECT_TENSOR_CLOSE (out, out_expected);
97+ TEST_F (OpProdIntOutTest, SmokeTest){
98+ #define TEST_ENTRY (ctype, dtype ) test_dtype<ScalarType::dtype>();
99+ ET_FORALL_REALHBBF16_TYPES (TEST_ENTRY)
100+ #undef TEST_ENTRY
81101}
82102
83103TEST_F (OpProdIntOutTest, SmokeTestKeepdim) {
0 commit comments