Skip to content

Commit da9466d

Browse files
committed
[ExecuTorch] generalize tests for unary_ufunc_realhb_to_floath ops (1/2)
Pull Request resolved: #5674 We have 20-odd ops that just map a unary function over each element. We can share the test infrastructure rather than create a bunch of similar-looking tests. Note that I attempted to minimize the amount of code inside `IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST`; it is just a bunch of `TEST_F` calls that call into plain old C++ functions, so we haven't merely copied the tests into a macro. This diff only makes one op adopt the test to prove it works; following diff(s) will roll out fully. ghstack-source-id: 244771365 @exported-using-ghexport Differential Revision: [D63431290](https://our.internmc.facebook.com/intern/diff/D63431290/)
1 parent d2ba238 commit da9466d

File tree

4 files changed

+311
-149
lines changed

4 files changed

+311
-149
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/test/UnaryUfuncRealHBToFloatHTest.h>
10+
11+
namespace torch::executor::testing {
12+
void UnaryUfuncRealHBToFloatHTest::test_bool_input() {
13+
TensorFactory<exec_aten::ScalarType::Bool> tf_bool;
14+
TensorFactory<exec_aten::ScalarType::Float> tf_float;
15+
16+
const std::vector<int32_t> sizes = {1, 2};
17+
18+
exec_aten::Tensor a = tf_bool.make(sizes, /*data=*/{false, true});
19+
exec_aten::Tensor out = tf_float.zeros(sizes);
20+
exec_aten::Tensor res = tf_float.make(
21+
sizes,
22+
/*data=*/{(float)op_reference(false), (float)op_reference(true)});
23+
24+
EXPECT_TENSOR_CLOSE(op_out(a, out), res);
25+
}
26+
27+
void UnaryUfuncRealHBToFloatHTest::test_mismatched_input_shapes_dies() {
28+
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
29+
GTEST_SKIP() << "ATen kernel can handle mismatched input shapes";
30+
}
31+
TensorFactory<exec_aten::ScalarType::Float> tf;
32+
33+
exec_aten::Tensor a = tf.ones(/*sizes=*/{4});
34+
exec_aten::Tensor out = tf.ones(/*sizes=*/{2, 2});
35+
36+
ET_EXPECT_KERNEL_FAILURE(context_, op_out(a, out));
37+
}
38+
39+
void UnaryUfuncRealHBToFloatHTest::
40+
test_all_real_input_half_output_static_dynamism_support() {
41+
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
42+
GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
43+
}
44+
#define TEST_ENTRY(ctype, dtype) \
45+
test_floating_point_op_out< \
46+
exec_aten::ScalarType::dtype, \
47+
exec_aten::ScalarType::Half>();
48+
ET_FORALL_REALH_TYPES(TEST_ENTRY);
49+
#undef TEST_ENTRY
50+
}
51+
52+
void UnaryUfuncRealHBToFloatHTest::
53+
test_all_real_input_float_output_static_dynamism_support() {
54+
#define TEST_ENTRY(ctype, dtype) \
55+
test_floating_point_op_out< \
56+
exec_aten::ScalarType::dtype, \
57+
exec_aten::ScalarType::Float>();
58+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
59+
#undef TEST_ENTRY
60+
}
61+
62+
void UnaryUfuncRealHBToFloatHTest::
63+
test_all_real_input_double_output_static_dynamism_support() {
64+
#define TEST_ENTRY(ctype, dtype) \
65+
test_floating_point_op_out< \
66+
exec_aten::ScalarType::dtype, \
67+
exec_aten::ScalarType::Double>();
68+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
69+
#undef TEST_ENTRY
70+
}
71+
72+
void UnaryUfuncRealHBToFloatHTest::
73+
test_all_real_input_half_output_bound_dynamism_support() {
74+
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
75+
GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
76+
}
77+
#define TEST_ENTRY(ctype, dtype) \
78+
test_floating_point_op_out< \
79+
exec_aten::ScalarType::dtype, \
80+
exec_aten::ScalarType::Half>( \
81+
{10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
82+
ET_FORALL_REALH_TYPES(TEST_ENTRY);
83+
#undef TEST_ENTRY
84+
}
85+
86+
void UnaryUfuncRealHBToFloatHTest::
87+
test_all_real_input_float_output_bound_dynamism_support() {
88+
#define TEST_ENTRY(ctype, dtype) \
89+
test_floating_point_op_out< \
90+
exec_aten::ScalarType::dtype, \
91+
exec_aten::ScalarType::Float>( \
92+
{10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
93+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
94+
#undef TEST_ENTRY
95+
}
96+
97+
void UnaryUfuncRealHBToFloatHTest::
98+
test_all_real_input_double_output_bound_dynamism_support() {
99+
#define TEST_ENTRY(ctype, dtype) \
100+
test_floating_point_op_out< \
101+
exec_aten::ScalarType::dtype, \
102+
exec_aten::ScalarType::Double>( \
103+
{10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
104+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
105+
#undef TEST_ENTRY
106+
}
107+
108+
void UnaryUfuncRealHBToFloatHTest::
109+
test_all_real_input_float_output_unbound_dynamism_support() {
110+
if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
111+
GTEST_SKIP() << "Dynamic shape unbound not supported";
112+
}
113+
#define TEST_ENTRY(ctype, dtype) \
114+
test_floating_point_op_out< \
115+
exec_aten::ScalarType::dtype, \
116+
exec_aten::ScalarType::Float>( \
117+
{1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
118+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
119+
#undef TEST_ENTRY
120+
}
121+
122+
void UnaryUfuncRealHBToFloatHTest::
123+
test_all_real_input_double_output_unbound_dynamism_support() {
124+
if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
125+
GTEST_SKIP() << "Dynamic shape unbound not supported";
126+
}
127+
#define TEST_ENTRY(ctype, dtype) \
128+
test_floating_point_op_out< \
129+
exec_aten::ScalarType::dtype, \
130+
exec_aten::ScalarType::Double>( \
131+
{1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
132+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
133+
#undef TEST_ENTRY
134+
}
135+
136+
void UnaryUfuncRealHBToFloatHTest::test_non_float_output_dtype_dies() {
137+
#define TEST_ENTRY(ctype, dtype) \
138+
test_op_invalid_output_dtype_dies< \
139+
exec_aten::ScalarType::Float, \
140+
exec_aten::ScalarType::dtype>();
141+
ET_FORALL_INT_TYPES(TEST_ENTRY);
142+
#undef TEST_ENTRY
143+
}
144+
145+
} // namespace torch::executor::testing
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/kernels/test/TestUtil.h>
12+
#include <executorch/kernels/test/supported_features.h>
13+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
14+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
15+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
16+
17+
#include <gtest/gtest.h>
18+
19+
namespace torch::executor::testing {
20+
// Generic test harness for ops that use unary_ufunc_realhb_to_floath
21+
// -- in other words, ops that just apply an elementwise function
22+
// mapping to a float or half.
23+
class UnaryUfuncRealHBToFloatHTest : public OperatorTest {
24+
protected:
25+
// Implement this to call the torch::executor::aten::op_outf function for the
26+
// top.
27+
virtual exec_aten::Tensor& op_out(
28+
const exec_aten::Tensor& self,
29+
exec_aten::Tensor& out) = 0;
30+
31+
// Scalar reference implementation of the function in question for testing.
32+
virtual double op_reference(double x) const = 0;
33+
34+
template <exec_aten::ScalarType IN_DTYPE, exec_aten::ScalarType OUT_DTYPE>
35+
void test_floating_point_op_out(
36+
const std::vector<int32_t>& out_shape = {1, 6},
37+
exec_aten::TensorShapeDynamism dynamism =
38+
exec_aten::TensorShapeDynamism::STATIC) {
39+
TensorFactory<IN_DTYPE> tf_in;
40+
TensorFactory<OUT_DTYPE> tf_out;
41+
42+
exec_aten::Tensor out = tf_out.zeros(out_shape, dynamism);
43+
44+
std::vector<typename decltype(tf_in)::ctype> test_vector = {
45+
0, 1, 3, 5, 10, 100};
46+
std::vector<typename decltype(tf_out)::ctype> expected_vector;
47+
std::transform(
48+
test_vector.begin(),
49+
test_vector.end(),
50+
std::back_inserter(expected_vector),
51+
[this](auto x) { return this->op_reference(x); });
52+
53+
// clang-format off
54+
op_out(tf_in.make({1, 6}, test_vector), out);
55+
56+
EXPECT_TENSOR_CLOSE(
57+
out,
58+
tf_out.make({1, 6}, expected_vector));
59+
// clang-format on
60+
}
61+
62+
// Unhandled output dtypes.
63+
template <
64+
exec_aten::ScalarType INPUT_DTYPE,
65+
exec_aten::ScalarType OUTPUT_DTYPE>
66+
void test_op_invalid_output_dtype_dies() {
67+
TensorFactory<INPUT_DTYPE> tf;
68+
TensorFactory<OUTPUT_DTYPE> tf_out;
69+
70+
const std::vector<int32_t> sizes = {2, 5};
71+
72+
exec_aten::Tensor in = tf.ones(sizes);
73+
exec_aten::Tensor out = tf_out.zeros(sizes);
74+
75+
ET_EXPECT_KERNEL_FAILURE(context_, op_out(in, out));
76+
}
77+
78+
void test_bool_input();
79+
80+
void test_mismatched_input_shapes_dies();
81+
82+
void test_all_real_input_half_output_static_dynamism_support();
83+
84+
void test_all_real_input_float_output_static_dynamism_support();
85+
86+
void test_all_real_input_double_output_static_dynamism_support();
87+
88+
void test_all_real_input_half_output_bound_dynamism_support();
89+
90+
void test_all_real_input_float_output_bound_dynamism_support();
91+
92+
void test_all_real_input_double_output_bound_dynamism_support();
93+
94+
void test_all_real_input_float_output_unbound_dynamism_support();
95+
96+
void test_all_real_input_double_output_unbound_dynamism_support();
97+
98+
void test_non_float_output_dtype_dies();
99+
};
100+
101+
#define IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST(TestName) \
102+
TEST_F(TestName, HandleBoolInput) { \
103+
test_bool_input(); \
104+
} \
105+
TEST_F(TestName, AllRealInputHalfOutputStaticDynamismSupport) { \
106+
test_all_real_input_half_output_static_dynamism_support(); \
107+
} \
108+
\
109+
TEST_F(TestName, AllRealInputFloatOutputStaticDynamismSupport) { \
110+
test_all_real_input_float_output_static_dynamism_support(); \
111+
} \
112+
\
113+
TEST_F(TestName, AllRealInputDoubleOutputStaticDynamismSupport) { \
114+
test_all_real_input_double_output_static_dynamism_support(); \
115+
} \
116+
\
117+
TEST_F(TestName, AllRealInputHalfOutputBoundDynamismSupport) { \
118+
test_all_real_input_half_output_bound_dynamism_support(); \
119+
} \
120+
\
121+
TEST_F(TestName, AllRealInputFloatOutputBoundDynamismSupport) { \
122+
test_all_real_input_float_output_bound_dynamism_support(); \
123+
} \
124+
\
125+
TEST_F(TestName, AllRealInputDoubleOutputBoundDynamismSupport) { \
126+
test_all_real_input_double_output_bound_dynamism_support(); \
127+
} \
128+
\
129+
TEST_F(TestName, AllRealInputFloatOutputUnboundDynamismSupport) { \
130+
test_all_real_input_float_output_unbound_dynamism_support(); \
131+
} \
132+
\
133+
TEST_F(TestName, AllRealInputDoubleOutputUnboundDynamismSupport) { \
134+
test_all_real_input_double_output_unbound_dynamism_support(); \
135+
} \
136+
\
137+
TEST_F(TestName, AllNonFloatOutputDTypeDies) { \
138+
test_non_float_output_dtype_dies(); \
139+
} \
140+
\
141+
TEST_F(TestName, MismatchedInputShapesDies) { \
142+
test_mismatched_input_shapes_dies(); \
143+
}
144+
145+
} // namespace torch::executor::testing

0 commit comments

Comments
 (0)