diff --git a/kernels/test/ScalarOverflowTestMacros.h b/kernels/test/ScalarOverflowTestMacros.h new file mode 100644 index 00000000000..46a2425b0fa --- /dev/null +++ b/kernels/test/ScalarOverflowTestMacros.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// Macro to generate scalar overflow test cases for a given test suite. +// The test suite must have a method called expect_bad_scalar_value_dies +// that takes a template parameter for ScalarType and a Scalar value. +#define GENERATE_SCALAR_OVERFLOW_TESTS(TEST_SUITE_NAME) \ + TEST_F(TEST_SUITE_NAME, ByteTensorTooLargeScalarDies) { \ + /* Cannot be represented by a uint8_t. */ \ + expect_bad_scalar_value_dies(256); \ + } \ + \ + TEST_F(TEST_SUITE_NAME, CharTensorTooSmallScalarDies) { \ + /* Cannot be represented by a int8_t. */ \ + expect_bad_scalar_value_dies(-129); \ + } \ + \ + TEST_F(TEST_SUITE_NAME, ShortTensorTooLargeScalarDies) { \ + /* Cannot be represented by a int16_t. */ \ + expect_bad_scalar_value_dies(32768); \ + } \ + \ + TEST_F(TEST_SUITE_NAME, FloatTensorTooSmallScalarDies) { \ + /* Cannot be represented by a float. */ \ + expect_bad_scalar_value_dies(-3.41e+38); \ + } \ + \ + TEST_F(TEST_SUITE_NAME, FloatTensorTooLargeScalarDies) { \ + /* Cannot be represented by a float. */ \ + expect_bad_scalar_value_dies(3.41e+38); \ + } diff --git a/kernels/test/op_fill_test.cpp b/kernels/test/op_fill_test.cpp index 0de49374477..c1c50206152 100644 --- a/kernels/test/op_fill_test.cpp +++ b/kernels/test/op_fill_test.cpp @@ -7,6 +7,7 @@ */ #include // Declares the operator +#include #include #include #include @@ -167,27 +168,4 @@ TEST_F(OpFillTest, MismatchedOutputDtypeDies) { ET_EXPECT_KERNEL_FAILURE(context_, op_fill_scalar_out(self, 0.0, out)); } -TEST_F(OpFillTest, ByteTensorTooLargeScalarDies) { - // Cannot be represented by a uint8_t. - expect_bad_scalar_value_dies(256); -} - -TEST_F(OpFillTest, CharTensorTooSmallScalarDies) { - // Cannot be represented by a int8_t. - expect_bad_scalar_value_dies(-129); -} - -TEST_F(OpFillTest, ShortTensorTooLargeScalarDies) { - // Cannot be represented by a int16_t. - expect_bad_scalar_value_dies(32768); -} - -TEST_F(OpFillTest, FloatTensorTooSmallScalarDies) { - // Cannot be represented by a float. - expect_bad_scalar_value_dies(-3.41e+38); -} - -TEST_F(OpFillTest, FloatTensorTooLargeScalarDies) { - // Cannot be represented by a float. - expect_bad_scalar_value_dies(3.41e+38); -} +GENERATE_SCALAR_OVERFLOW_TESTS(OpFillTest) diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index de3b2317f2b..60dabac1844 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -50,6 +50,7 @@ def define_common_targets(): ], exported_headers = [ "BinaryLogicalOpTest.h", + "ScalarOverflowTestMacros.h", "UnaryUfuncRealHBBF16ToFloatHBF16Test.h", ], visibility = [