Skip to content

Commit dd44442

Browse files
committed
[ExecuTorch] generalize tests for unary_ufunc_realhb_to_floath ops (1/2)
We have 20-odd ops that just map a unary function over each element. We can share the test infrastructure rather than create a bunch of similar-looking tests. Note that I attempted to minimize the amount of code inside `IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST`; it is just a bunch of `TEST_F` calls that call into plain old C++ functions, so we haven't merely copied the tests into a macro. This diff only makes one op adopt the test to prove it works; following diff(s) will roll out fully. Differential Revision: [D63431290](https://our.internmc.facebook.com/intern/diff/D63431290/) ghstack-source-id: 244762688 Pull Request resolved: #5674
1 parent d2ba238 commit dd44442

File tree

3 files changed

+261
-149
lines changed

3 files changed

+261
-149
lines changed
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/kernels/test/TestUtil.h>
12+
#include <executorch/kernels/test/supported_features.h>
13+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
14+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
15+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
16+
17+
#include <gtest/gtest.h>
18+
19+
namespace torch::executor::testing {
20+
// Generic test harness for ops that use unary_ufunc_realhb_to_floath
21+
// -- in other words, ops that just apply an elementwise function
22+
// mapping to a float or half.
23+
class UnaryUfuncRealHBToFloatHTest : public OperatorTest {
24+
protected:
25+
// Implement this to call the torch::executor::aten::op_outf function for the
26+
// top.
27+
virtual exec_aten::Tensor& op_out(
28+
const exec_aten::Tensor& self,
29+
exec_aten::Tensor& out) = 0;
30+
31+
// Scalar reference implementation of the function in question for testing.
32+
virtual double op_reference(double x) const = 0;
33+
34+
template <exec_aten::ScalarType IN_DTYPE, exec_aten::ScalarType OUT_DTYPE>
35+
void test_floating_point_op_out(
36+
const std::vector<int32_t>& out_shape = {1, 6},
37+
exec_aten::TensorShapeDynamism dynamism =
38+
exec_aten::TensorShapeDynamism::STATIC) {
39+
TensorFactory<IN_DTYPE> tf_in;
40+
TensorFactory<OUT_DTYPE> tf_out;
41+
42+
exec_aten::Tensor out = tf_out.zeros(out_shape, dynamism);
43+
44+
std::vector<typename decltype(tf_in)::ctype> test_vector = {
45+
0, 1, 3, 5, 10, 100};
46+
std::vector<typename decltype(tf_out)::ctype> expected_vector;
47+
std::transform(
48+
test_vector.begin(),
49+
test_vector.end(),
50+
std::back_inserter(expected_vector),
51+
[this](auto x) { return this->op_reference(x); });
52+
53+
// clang-format off
54+
op_out(tf_in.make({1, 6}, test_vector), out);
55+
56+
EXPECT_TENSOR_CLOSE(
57+
out,
58+
tf_out.make({1, 6}, expected_vector));
59+
// clang-format on
60+
}
61+
// Unhandled output dtypes.
62+
template <
63+
exec_aten::ScalarType INPUT_DTYPE,
64+
exec_aten::ScalarType OUTPUT_DTYPE>
65+
void test_op_invalid_output_dtype_dies() {
66+
TensorFactory<INPUT_DTYPE> tf;
67+
TensorFactory<OUTPUT_DTYPE> tf_out;
68+
69+
const std::vector<int32_t> sizes = {2, 5};
70+
71+
exec_aten::Tensor in = tf.ones(sizes);
72+
exec_aten::Tensor out = tf_out.zeros(sizes);
73+
74+
ET_EXPECT_KERNEL_FAILURE(context_, op_out(in, out));
75+
}
76+
77+
void test_bool_input() {
78+
TensorFactory<exec_aten::ScalarType::Bool> tf_bool;
79+
TensorFactory<exec_aten::ScalarType::Float> tf_float;
80+
81+
const std::vector<int32_t> sizes = {1, 2};
82+
83+
exec_aten::Tensor a = tf_bool.make(sizes, /*data=*/{false, true});
84+
exec_aten::Tensor out = tf_float.zeros(sizes);
85+
exec_aten::Tensor res = tf_float.make(
86+
sizes,
87+
/*data=*/{(float)op_reference(false), (float)op_reference(true)});
88+
89+
EXPECT_TENSOR_CLOSE(op_out(a, out), res);
90+
}
91+
92+
void test_mismatched_input_shapes_dies() {
93+
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
94+
GTEST_SKIP() << "ATen kernel can handle mismatched input shapes";
95+
}
96+
TensorFactory<exec_aten::ScalarType::Float> tf;
97+
98+
exec_aten::Tensor a = tf.ones(/*sizes=*/{4});
99+
exec_aten::Tensor out = tf.ones(/*sizes=*/{2, 2});
100+
101+
ET_EXPECT_KERNEL_FAILURE(context_, op_out(a, out));
102+
}
103+
104+
void test_all_real_input_half_output_static_dynamism_support() {
105+
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
106+
GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
107+
}
108+
#define TEST_ENTRY(ctype, dtype) \
109+
test_floating_point_op_out< \
110+
exec_aten::ScalarType::dtype, \
111+
exec_aten::ScalarType::Half>();
112+
ET_FORALL_REALH_TYPES(TEST_ENTRY);
113+
#undef TEST_ENTRY
114+
}
115+
116+
void test_all_real_input_float_output_static_dynamism_support() {
117+
#define TEST_ENTRY(ctype, dtype) \
118+
test_floating_point_op_out< \
119+
exec_aten::ScalarType::dtype, \
120+
exec_aten::ScalarType::Float>();
121+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
122+
#undef TEST_ENTRY
123+
}
124+
125+
void test_all_real_input_double_output_static_dynamism_support() {
126+
#define TEST_ENTRY(ctype, dtype) \
127+
test_floating_point_op_out< \
128+
exec_aten::ScalarType::dtype, \
129+
exec_aten::ScalarType::Double>();
130+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
131+
#undef TEST_ENTRY
132+
}
133+
134+
void test_all_real_input_half_output_bound_dynamism_support() {
135+
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
136+
GTEST_SKIP() << "Test Half support only for ExecuTorch mode";
137+
}
138+
#define TEST_ENTRY(ctype, dtype) \
139+
test_floating_point_op_out< \
140+
exec_aten::ScalarType::dtype, \
141+
exec_aten::ScalarType::Half>( \
142+
{10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
143+
ET_FORALL_REALH_TYPES(TEST_ENTRY);
144+
#undef TEST_ENTRY
145+
}
146+
147+
void test_all_real_input_float_output_bound_dynamism_support() {
148+
#define TEST_ENTRY(ctype, dtype) \
149+
test_floating_point_op_out< \
150+
exec_aten::ScalarType::dtype, \
151+
exec_aten::ScalarType::Float>( \
152+
{10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
153+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
154+
#undef TEST_ENTRY
155+
}
156+
157+
void test_all_real_input_double_output_bound_dynamism_support() {
158+
#define TEST_ENTRY(ctype, dtype) \
159+
test_floating_point_op_out< \
160+
exec_aten::ScalarType::dtype, \
161+
exec_aten::ScalarType::Double>( \
162+
{10, 10}, exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
163+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
164+
#undef TEST_ENTRY
165+
}
166+
167+
void test_all_real_input_float_output_unbound_dynamism_support() {
168+
if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
169+
GTEST_SKIP() << "Dynamic shape unbound not supported";
170+
}
171+
#define TEST_ENTRY(ctype, dtype) \
172+
test_floating_point_op_out< \
173+
exec_aten::ScalarType::dtype, \
174+
exec_aten::ScalarType::Float>( \
175+
{1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
176+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
177+
#undef TEST_ENTRY
178+
}
179+
180+
void test_all_real_input_double_output_unbound_dynamism_support() {
181+
if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
182+
GTEST_SKIP() << "Dynamic shape unbound not supported";
183+
}
184+
#define TEST_ENTRY(ctype, dtype) \
185+
test_floating_point_op_out< \
186+
exec_aten::ScalarType::dtype, \
187+
exec_aten::ScalarType::Double>( \
188+
{1, 1}, exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
189+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
190+
#undef TEST_ENTRY
191+
}
192+
193+
void test_non_float_output_dtype_dies() {
194+
#define TEST_ENTRY(ctype, dtype) \
195+
test_op_invalid_output_dtype_dies< \
196+
exec_aten::ScalarType::Float, \
197+
exec_aten::ScalarType::dtype>();
198+
ET_FORALL_INT_TYPES(TEST_ENTRY);
199+
#undef TEST_ENTRY
200+
}
201+
};
202+
203+
#define IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST(TestName) \
204+
TEST_F(TestName, HandleBoolInput) { \
205+
test_bool_input(); \
206+
} \
207+
TEST_F(TestName, AllRealInputHalfOutputStaticDynamismSupport) { \
208+
test_all_real_input_half_output_static_dynamism_support(); \
209+
} \
210+
\
211+
TEST_F(TestName, AllRealInputFloatOutputStaticDynamismSupport) { \
212+
test_all_real_input_float_output_static_dynamism_support(); \
213+
} \
214+
\
215+
TEST_F(TestName, AllRealInputDoubleOutputStaticDynamismSupport) { \
216+
test_all_real_input_double_output_static_dynamism_support(); \
217+
} \
218+
\
219+
TEST_F(TestName, AllRealInputHalfOutputBoundDynamismSupport) { \
220+
test_all_real_input_half_output_bound_dynamism_support(); \
221+
} \
222+
\
223+
TEST_F(TestName, AllRealInputFloatOutputBoundDynamismSupport) { \
224+
test_all_real_input_float_output_bound_dynamism_support(); \
225+
} \
226+
\
227+
TEST_F(TestName, AllRealInputDoubleOutputBoundDynamismSupport) { \
228+
test_all_real_input_double_output_bound_dynamism_support(); \
229+
} \
230+
\
231+
TEST_F(TestName, AllRealInputFloatOutputUnboundDynamismSupport) { \
232+
test_all_real_input_float_output_unbound_dynamism_support(); \
233+
} \
234+
\
235+
TEST_F(TestName, AllRealInputDoubleOutputUnboundDynamismSupport) { \
236+
test_all_real_input_double_output_unbound_dynamism_support(); \
237+
} \
238+
\
239+
TEST_F(TestName, AllNonFloatOutputDTypeDies) { \
240+
test_non_float_output_dtype_dies(); \
241+
} \
242+
\
243+
TEST_F(TestName, MismatchedInputShapesDies) { \
244+
test_mismatched_input_shapes_dies(); \
245+
}
246+
247+
} // namespace torch::executor::testing

0 commit comments

Comments
 (0)