Skip to content

Commit 1e4b8c1

Browse files
[ET][Portable] Check scalar overflow: op_fill (#12042)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12012 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/119/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/119/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/118/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/119/orig @diff-train-skip-merge --------- Co-authored-by: Manuel Candales <[email protected]> Co-authored-by: Manuel Candales <[email protected]>
1 parent 084b064 commit 1e4b8c1

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

kernels/portable/cpu/op_fill.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ Tensor& fill_scalar_out(
4242
"Failed to resize output tensor.");
4343

4444
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] {
45-
const CTYPE_A b_casted = utils::scalar_to<CTYPE_A>(b);
45+
auto opt_b_casted = utils::internal::check_overflow_scalar_cast<CTYPE_A>(b);
46+
ET_KERNEL_CHECK(ctx, opt_b_casted.has_value(), InvalidArgument, );
47+
auto b_casted = opt_b_casted.value();
4648

4749
apply_unary_map_fn(
4850
[b_casted](const CTYPE_A val_a) { return b_casted; },

kernels/test/op_fill_test.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,15 @@ class OpFillTest : public OperatorTest {
7474
// Check `out` matches expected output.
7575
EXPECT_TENSOR_EQ(out, exp_out);
7676
}
77+
78+
template <ScalarType DTYPE>
79+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
80+
TensorFactory<DTYPE> tf;
81+
Tensor a = tf.ones({2, 2});
82+
Tensor out = tf.zeros({2, 2});
83+
84+
ET_EXPECT_KERNEL_FAILURE(context_, op_fill_scalar_out(a, bad_value, out));
85+
}
7786
};
7887

7988
// A macro for defining tests for both scalar and tensor variants of
@@ -157,3 +166,28 @@ TEST_F(OpFillTest, MismatchedOutputDtypeDies) {
157166
// Assert `out` can't be filled due to incompatible dtype.
158167
ET_EXPECT_KERNEL_FAILURE(context_, op_fill_scalar_out(self, 0.0, out));
159168
}
169+
170+
TEST_F(OpFillTest, ByteTensorTooLargeScalarDies) {
171+
// Cannot be represented by a uint8_t.
172+
expect_bad_scalar_value_dies<ScalarType::Byte>(256);
173+
}
174+
175+
TEST_F(OpFillTest, CharTensorTooSmallScalarDies) {
176+
// Cannot be represented by a int8_t.
177+
expect_bad_scalar_value_dies<ScalarType::Char>(-129);
178+
}
179+
180+
TEST_F(OpFillTest, ShortTensorTooLargeScalarDies) {
181+
// Cannot be represented by a int16_t.
182+
expect_bad_scalar_value_dies<ScalarType::Short>(32768);
183+
}
184+
185+
TEST_F(OpFillTest, FloatTensorTooSmallScalarDies) {
186+
// Cannot be represented by a float.
187+
expect_bad_scalar_value_dies<ScalarType::Float>(-3.41e+38);
188+
}
189+
190+
TEST_F(OpFillTest, FloatTensorTooLargeScalarDies) {
191+
// Cannot be represented by a float.
192+
expect_bad_scalar_value_dies<ScalarType::Float>(3.41e+38);
193+
}

0 commit comments

Comments
 (0)