@@ -74,6 +74,15 @@ class OpFillTest : public OperatorTest {
7474 // Check `out` matches expected output.
7575 EXPECT_TENSOR_EQ (out, exp_out);
7676 }
77+
78+ template <ScalarType DTYPE>
79+ void expect_bad_scalar_value_dies (const Scalar& bad_value) {
80+ TensorFactory<DTYPE> tf;
81+ Tensor a = tf.ones ({2 , 2 });
82+ Tensor out = tf.zeros ({2 , 2 });
83+
84+ ET_EXPECT_KERNEL_FAILURE (context_, op_fill_scalar_out (a, bad_value, out));
85+ }
7786};
7887
7988// A macro for defining tests for both scalar and tensor variants of
@@ -157,3 +166,28 @@ TEST_F(OpFillTest, MismatchedOutputDtypeDies) {
157166 // Assert `out` can't be filled due to incompatible dtype.
158167 ET_EXPECT_KERNEL_FAILURE (context_, op_fill_scalar_out (self, 0.0 , out));
159168}
169+
170+ TEST_F (OpFillTest, ByteTensorTooLargeScalarDies) {
171+ // Cannot be represented by a uint8_t.
172+ expect_bad_scalar_value_dies<ScalarType::Byte>(256 );
173+ }
174+
175+ TEST_F (OpFillTest, CharTensorTooSmallScalarDies) {
176+ // Cannot be represented by a int8_t.
177+ expect_bad_scalar_value_dies<ScalarType::Char>(-129 );
178+ }
179+
180+ TEST_F (OpFillTest, ShortTensorTooLargeScalarDies) {
181+ // Cannot be represented by a int16_t.
182+ expect_bad_scalar_value_dies<ScalarType::Short>(32768 );
183+ }
184+
185+ TEST_F (OpFillTest, FloatTensorTooSmallScalarDies) {
186+ // Cannot be represented by a float.
187+ expect_bad_scalar_value_dies<ScalarType::Float>(-3.41e+38 );
188+ }
189+
190+ TEST_F (OpFillTest, FloatTensorTooLargeScalarDies) {
191+ // Cannot be represented by a float.
192+ expect_bad_scalar_value_dies<ScalarType::Float>(3.41e+38 );
193+ }
0 commit comments