Skip to content

Commit 829ba3e

Browse files
swolchokfacebook-github-bot
authored andcommitted
generalize tests for unary_ufunc_realhb_to_floath ops (1/2) (#5674)
Summary: Pull Request resolved: pytorch/executorch#5674 We have 20-odd ops that just map a unary function over each element. We can share the test infrastructure rather than create a bunch of similar-looking tests. Note that I attempted to minimize the amount of code inside `IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST`; it is just a bunch of `TEST_F` calls that call into plain old C++ functions, so we haven't merely copied the tests into a macro. This diff only makes one op adopt the test to prove it works; following diff(s) will roll out fully. ghstack-source-id: 245555539 exported-using-ghexport Reviewed By: manuelcandales Differential Revision: D63431290 fbshipit-source-id: 4ea1db3f3a708059731cbb5d328d2474ebfb41a2
1 parent 04669a1 commit 829ba3e

File tree

5 files changed

+332
-152
lines changed

5 files changed

+332
-152
lines changed

kernels/test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,8 @@ set(all_test_sources
211211
"op_view_copy_test.cpp"
212212
"op_where_test.cpp"
213213
"op_zeros_test.cpp"
214-
)
214+
"UnaryUfuncRealHBToFloatHTest.cpp"
215+
)
215216

216217
set(_portable_kernels_test_sources
217218
${all_test_sources}
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/test/UnaryUfuncRealHBToFloatHTest.h>
10+
11+
namespace torch::executor::testing {
12+
void UnaryUfuncRealHBToFloatHTest::test_bool_input() {
13+
TensorFactory<exec_aten::ScalarType::Bool> tf_bool;
14+
TensorFactory<exec_aten::ScalarType::Float> tf_float;
15+
16+
const std::vector<int32_t> sizes = {1, 2};
17+
18+
exec_aten::Tensor a = tf_bool.make(sizes, /*data=*/{false, true});
19+
exec_aten::Tensor out = tf_float.zeros(sizes);
20+
exec_aten::Tensor res = tf_float.make(
21+
sizes,
22+
/*data=*/{(float)op_reference(false), (float)op_reference(true)});
23+
24+
EXPECT_TENSOR_CLOSE(op_out(a, out), res);
25+
}
26+
27+
void UnaryUfuncRealHBToFloatHTest::test_mismatched_input_shapes_dies() {
28+
if (get_supported_features()->is_aten) {
29+
GTEST_SKIP() << "ATen kernel can handle mismatched input shapes";
30+
}
31+
TensorFactory<exec_aten::ScalarType::Float> tf;
32+
33+
exec_aten::Tensor a = tf.ones(/*sizes=*/{4});
34+
exec_aten::Tensor out = tf.ones(/*sizes=*/{2, 2});
35+
36+
ET_EXPECT_KERNEL_FAILURE(context_, op_out(a, out));
37+
}
38+
39+
void UnaryUfuncRealHBToFloatHTest::
40+
test_all_real_input_half_output_static_dynamism_support() {
41+
if (get_supported_features()->is_aten) {
42+
GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
43+
}
44+
#define TEST_ENTRY(ctype, dtype) \
45+
test_floating_point_op_out< \
46+
exec_aten::ScalarType::dtype, \
47+
exec_aten::ScalarType::Half>();
48+
ET_FORALL_REALH_TYPES(TEST_ENTRY);
49+
#undef TEST_ENTRY
50+
}
51+
52+
void UnaryUfuncRealHBToFloatHTest::
53+
test_all_real_input_float_output_static_dynamism_support() {
54+
#define TEST_ENTRY(ctype, dtype) \
55+
test_floating_point_op_out< \
56+
exec_aten::ScalarType::dtype, \
57+
exec_aten::ScalarType::Float>();
58+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
59+
#undef TEST_ENTRY
60+
}
61+
62+
void UnaryUfuncRealHBToFloatHTest::
63+
test_all_real_input_double_output_static_dynamism_support() {
64+
#define TEST_ENTRY(ctype, dtype) \
65+
test_floating_point_op_out< \
66+
exec_aten::ScalarType::dtype, \
67+
exec_aten::ScalarType::Double>();
68+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
69+
#undef TEST_ENTRY
70+
}
71+
72+
void UnaryUfuncRealHBToFloatHTest::
73+
test_all_real_input_half_output_bound_dynamism_support() {
74+
if (get_supported_features()->is_aten) {
75+
GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
76+
}
77+
#define TEST_ENTRY(ctype, dtype) \
78+
test_floating_point_op_out< \
79+
exec_aten::ScalarType::dtype, \
80+
exec_aten::ScalarType::Half>( \
81+
{10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
82+
ET_FORALL_REALH_TYPES(TEST_ENTRY);
83+
#undef TEST_ENTRY
84+
}
85+
86+
void UnaryUfuncRealHBToFloatHTest::
87+
test_all_real_input_float_output_bound_dynamism_support() {
88+
#define TEST_ENTRY(ctype, dtype) \
89+
test_floating_point_op_out< \
90+
exec_aten::ScalarType::dtype, \
91+
exec_aten::ScalarType::Float>( \
92+
{10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
93+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
94+
#undef TEST_ENTRY
95+
}
96+
97+
void UnaryUfuncRealHBToFloatHTest::
98+
test_all_real_input_double_output_bound_dynamism_support() {
99+
#define TEST_ENTRY(ctype, dtype) \
100+
test_floating_point_op_out< \
101+
exec_aten::ScalarType::dtype, \
102+
exec_aten::ScalarType::Double>( \
103+
{10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
104+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
105+
#undef TEST_ENTRY
106+
}
107+
108+
void UnaryUfuncRealHBToFloatHTest::
109+
test_all_real_input_float_output_unbound_dynamism_support() {
110+
if (!get_supported_features()->is_aten) {
111+
GTEST_SKIP() << "Dynamic shape unbound not supported";
112+
}
113+
#define TEST_ENTRY(ctype, dtype) \
114+
test_floating_point_op_out< \
115+
exec_aten::ScalarType::dtype, \
116+
exec_aten::ScalarType::Float>( \
117+
{1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
118+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
119+
#undef TEST_ENTRY
120+
}
121+
122+
void UnaryUfuncRealHBToFloatHTest::
123+
test_all_real_input_double_output_unbound_dynamism_support() {
124+
if (!get_supported_features()->is_aten) {
125+
GTEST_SKIP() << "Dynamic shape unbound not supported";
126+
}
127+
#define TEST_ENTRY(ctype, dtype) \
128+
test_floating_point_op_out< \
129+
exec_aten::ScalarType::dtype, \
130+
exec_aten::ScalarType::Double>( \
131+
{1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
132+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
133+
#undef TEST_ENTRY
134+
}
135+
136+
void UnaryUfuncRealHBToFloatHTest::test_non_float_output_dtype_dies() {
137+
#define TEST_ENTRY(ctype, dtype) \
138+
test_op_invalid_output_dtype_dies< \
139+
exec_aten::ScalarType::Float, \
140+
exec_aten::ScalarType::dtype>();
141+
ET_FORALL_INT_TYPES(TEST_ENTRY);
142+
#undef TEST_ENTRY
143+
}
144+
145+
} // namespace torch::executor::testing
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/kernels/test/TestUtil.h>
12+
#include <executorch/kernels/test/supported_features.h>
13+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
14+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
15+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
16+
17+
#include <gtest/gtest.h>
18+
19+
namespace torch::executor::testing {
20+
// Generic test harness for ops that use unary_ufunc_realhb_to_floath
21+
// -- in other words, ops that just apply an elementwise function
22+
// mapping to a float or half.
23+
class UnaryUfuncRealHBToFloatHTest : public OperatorTest {
24+
protected:
25+
// Implement this to call the torch::executor::aten::op_outf function for the
26+
// op.
27+
virtual exec_aten::Tensor& op_out(
28+
const exec_aten::Tensor& self,
29+
exec_aten::Tensor& out) = 0;
30+
31+
// Scalar reference implementation of the function in question for testing.
32+
virtual double op_reference(double x) const = 0;
33+
34+
// The SupportedFeatures system assumes that it can build each test
35+
// target with a separate SupportedFeatures (really just one
36+
// portable, one optimzed but between one and the infinite, two is
37+
// ridiculous and can't exist). We work around that by calling
38+
// SupportedFeatures::get() in the concrete test translation
39+
// unit. You need to declare an override, but we implement it for you
40+
// in IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST.
41+
virtual SupportedFeatures* get_supported_features() const = 0;
42+
43+
template <exec_aten::ScalarType IN_DTYPE, exec_aten::ScalarType OUT_DTYPE>
44+
void test_floating_point_op_out(
45+
const std::vector<int32_t>& out_shape = {1, 6},
46+
exec_aten::TensorShapeDynamism dynamism =
47+
exec_aten::TensorShapeDynamism::STATIC) {
48+
TensorFactory<IN_DTYPE> tf_in;
49+
TensorFactory<OUT_DTYPE> tf_out;
50+
51+
exec_aten::Tensor out = tf_out.zeros(out_shape, dynamism);
52+
53+
std::vector<typename decltype(tf_in)::ctype> test_vector = {
54+
0, 1, 3, 5, 10, 100};
55+
std::vector<typename decltype(tf_out)::ctype> expected_vector;
56+
std::transform(
57+
test_vector.begin(),
58+
test_vector.end(),
59+
std::back_inserter(expected_vector),
60+
[this](auto x) { return this->op_reference(x); });
61+
62+
// clang-format off
63+
op_out(tf_in.make({1, 6}, test_vector), out);
64+
65+
EXPECT_TENSOR_CLOSE(
66+
out,
67+
tf_out.make({1, 6}, expected_vector));
68+
// clang-format on
69+
}
70+
71+
// Unhandled output dtypes.
72+
template <
73+
exec_aten::ScalarType INPUT_DTYPE,
74+
exec_aten::ScalarType OUTPUT_DTYPE>
75+
void test_op_invalid_output_dtype_dies() {
76+
TensorFactory<INPUT_DTYPE> tf;
77+
TensorFactory<OUTPUT_DTYPE> tf_out;
78+
79+
const std::vector<int32_t> sizes = {2, 5};
80+
81+
exec_aten::Tensor in = tf.ones(sizes);
82+
exec_aten::Tensor out = tf_out.zeros(sizes);
83+
84+
ET_EXPECT_KERNEL_FAILURE(context_, op_out(in, out));
85+
}
86+
87+
void test_bool_input();
88+
89+
void test_mismatched_input_shapes_dies();
90+
91+
void test_all_real_input_half_output_static_dynamism_support();
92+
93+
void test_all_real_input_float_output_static_dynamism_support();
94+
95+
void test_all_real_input_double_output_static_dynamism_support();
96+
97+
void test_all_real_input_half_output_bound_dynamism_support();
98+
99+
void test_all_real_input_float_output_bound_dynamism_support();
100+
101+
void test_all_real_input_double_output_bound_dynamism_support();
102+
103+
void test_all_real_input_float_output_unbound_dynamism_support();
104+
105+
void test_all_real_input_double_output_unbound_dynamism_support();
106+
107+
void test_non_float_output_dtype_dies();
108+
};
109+
110+
#define IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST(TestName) \
111+
torch::executor::testing::SupportedFeatures* \
112+
TestName::get_supported_features() const { \
113+
return torch::executor::testing::SupportedFeatures::get(); \
114+
} \
115+
TEST_F(TestName, HandleBoolInput) { \
116+
test_bool_input(); \
117+
} \
118+
TEST_F(TestName, AllRealInputHalfOutputStaticDynamismSupport) { \
119+
test_all_real_input_half_output_static_dynamism_support(); \
120+
} \
121+
\
122+
TEST_F(TestName, AllRealInputFloatOutputStaticDynamismSupport) { \
123+
test_all_real_input_float_output_static_dynamism_support(); \
124+
} \
125+
\
126+
TEST_F(TestName, AllRealInputDoubleOutputStaticDynamismSupport) { \
127+
test_all_real_input_double_output_static_dynamism_support(); \
128+
} \
129+
\
130+
TEST_F(TestName, AllRealInputHalfOutputBoundDynamismSupport) { \
131+
test_all_real_input_half_output_bound_dynamism_support(); \
132+
} \
133+
\
134+
TEST_F(TestName, AllRealInputFloatOutputBoundDynamismSupport) { \
135+
test_all_real_input_float_output_bound_dynamism_support(); \
136+
} \
137+
\
138+
TEST_F(TestName, AllRealInputDoubleOutputBoundDynamismSupport) { \
139+
test_all_real_input_double_output_bound_dynamism_support(); \
140+
} \
141+
\
142+
TEST_F(TestName, AllRealInputFloatOutputUnboundDynamismSupport) { \
143+
test_all_real_input_float_output_unbound_dynamism_support(); \
144+
} \
145+
\
146+
TEST_F(TestName, AllRealInputDoubleOutputUnboundDynamismSupport) { \
147+
test_all_real_input_double_output_unbound_dynamism_support(); \
148+
} \
149+
\
150+
TEST_F(TestName, AllNonFloatOutputDTypeDies) { \
151+
test_non_float_output_dtype_dies(); \
152+
} \
153+
\
154+
TEST_F(TestName, MismatchedInputShapesDies) { \
155+
test_mismatched_input_shapes_dies(); \
156+
}
157+
158+
} // namespace torch::executor::testing

0 commit comments

Comments
 (0)