@@ -74,6 +74,15 @@ class OpFillTest : public OperatorTest {
74
74
// Check `out` matches expected output.
75
75
EXPECT_TENSOR_EQ (out, exp_out);
76
76
}
77
+
78
+ template <ScalarType DTYPE>
79
+ void expect_bad_scalar_value_dies (const Scalar& bad_value) {
80
+ TensorFactory<DTYPE> tf;
81
+ Tensor a = tf.ones ({2 , 2 });
82
+ Tensor out = tf.zeros ({2 , 2 });
83
+
84
+ ET_EXPECT_KERNEL_FAILURE (context_, op_fill_scalar_out (a, bad_value, out));
85
+ }
77
86
};
78
87
79
88
// A macro for defining tests for both scalar and tensor variants of
@@ -157,3 +166,28 @@ TEST_F(OpFillTest, MismatchedOutputDtypeDies) {
157
166
// Assert `out` can't be filled due to incompatible dtype.
158
167
ET_EXPECT_KERNEL_FAILURE (context_, op_fill_scalar_out (self, 0.0 , out));
159
168
}
169
+
170
+ TEST_F (OpFillTest, ByteTensorTooLargeScalarDies) {
171
+ // Cannot be represented by a uint8_t.
172
+ expect_bad_scalar_value_dies<ScalarType::Byte>(256 );
173
+ }
174
+
175
+ TEST_F (OpFillTest, CharTensorTooSmallScalarDies) {
176
+ // Cannot be represented by a int8_t.
177
+ expect_bad_scalar_value_dies<ScalarType::Char>(-129 );
178
+ }
179
+
180
+ TEST_F (OpFillTest, ShortTensorTooLargeScalarDies) {
181
+ // Cannot be represented by a int16_t.
182
+ expect_bad_scalar_value_dies<ScalarType::Short>(32768 );
183
+ }
184
+
185
+ TEST_F (OpFillTest, FloatTensorTooSmallScalarDies) {
186
+ // Cannot be represented by a float.
187
+ expect_bad_scalar_value_dies<ScalarType::Float>(-3.41e+38 );
188
+ }
189
+
190
+ TEST_F (OpFillTest, FloatTensorTooLargeScalarDies) {
191
+ // Cannot be represented by a float.
192
+ expect_bad_scalar_value_dies<ScalarType::Float>(3.41e+38 );
193
+ }
0 commit comments