Skip to content

Commit f49cf00

Browse files
[ET][Test] Scalar overflow test macro (#12043)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12013 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/120/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/120/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/119/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/120/orig @diff-train-skip-merge --------- Co-authored-by: Manuel Candales <[email protected]> Co-authored-by: Manuel Candales <[email protected]>
1 parent 1e4b8c1 commit f49cf00

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
// Macro to generate scalar overflow test cases for a given test suite.
12+
// The test suite must have a method called expect_bad_scalar_value_dies
13+
// that takes a template parameter for ScalarType and a Scalar value.
14+
#define GENERATE_SCALAR_OVERFLOW_TESTS(TEST_SUITE_NAME) \
15+
TEST_F(TEST_SUITE_NAME, ByteTensorTooLargeScalarDies) { \
16+
/* Cannot be represented by a uint8_t. */ \
17+
expect_bad_scalar_value_dies<ScalarType::Byte>(256); \
18+
} \
19+
\
20+
TEST_F(TEST_SUITE_NAME, CharTensorTooSmallScalarDies) { \
21+
/* Cannot be represented by a int8_t. */ \
22+
expect_bad_scalar_value_dies<ScalarType::Char>(-129); \
23+
} \
24+
\
25+
TEST_F(TEST_SUITE_NAME, ShortTensorTooLargeScalarDies) { \
26+
/* Cannot be represented by a int16_t. */ \
27+
expect_bad_scalar_value_dies<ScalarType::Short>(32768); \
28+
} \
29+
\
30+
TEST_F(TEST_SUITE_NAME, FloatTensorTooSmallScalarDies) { \
31+
/* Cannot be represented by a float. */ \
32+
expect_bad_scalar_value_dies<ScalarType::Float>(-3.41e+38); \
33+
} \
34+
\
35+
TEST_F(TEST_SUITE_NAME, FloatTensorTooLargeScalarDies) { \
36+
/* Cannot be represented by a float. */ \
37+
expect_bad_scalar_value_dies<ScalarType::Float>(3.41e+38); \
38+
}

kernels/test/op_fill_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/ScalarOverflowTestMacros.h>
1011
#include <executorch/kernels/test/TestUtil.h>
1112
#include <executorch/kernels/test/supported_features.h>
1213
#include <executorch/runtime/core/exec_aten/exec_aten.h>

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def define_common_targets():
5050
],
5151
exported_headers = [
5252
"BinaryLogicalOpTest.h",
53+
"ScalarOverflowTestMacros.h",
5354
"UnaryUfuncRealHBBF16ToFloatHBF16Test.h",
5455
],
5556
visibility = [

0 commit comments

Comments
 (0)