Skip to content

Commit 5cbd790

Browse files
[ET][Portable] Check scalar overflow: op_full_like (#12081)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12015 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/122/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/122/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/121/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/122/orig @diff-train-skip-merge --------- Co-authored-by: Manuel Candales <[email protected]>
1 parent f25265b commit 5cbd790

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

kernels/portable/cpu/op_full_like.cpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,23 +48,19 @@ Tensor& full_like_out(
4848
out,
4949
"Failed to resize output tensor.");
5050

51-
ScalarType val_type = utils::get_scalar_dtype(fill_value);
5251
ScalarType out_type = out.scalar_type();
5352

5453
constexpr auto name = "scalar_tensor.out";
5554

56-
ET_SWITCH_REALB_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
57-
CTYPE_VAL val;
58-
ET_KERNEL_CHECK(
59-
ctx, utils::extract_scalar(fill_value, &val), InvalidArgument, );
60-
61-
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
62-
CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
63-
auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
64-
for (const auto i : c10::irange(out.numel())) {
65-
data_out[i] = val_casted;
66-
}
67-
});
55+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
56+
auto opt_val_casted =
57+
utils::internal::check_overflow_scalar_cast<CTYPE_OUT>(fill_value);
58+
ET_KERNEL_CHECK(ctx, opt_val_casted.has_value(), InvalidArgument, );
59+
auto val_casted = opt_val_casted.value();
60+
auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
61+
for (const auto i : c10::irange(out.numel())) {
62+
data_out[i] = val_casted;
63+
}
6864
});
6965

7066
return out;

kernels/test/op_full_like_test.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/ScalarOverflowTestMacros.h>
1011
#include <executorch/kernels/test/TestUtil.h>
1112
#include <executorch/kernels/test/supported_features.h>
1213
#include <executorch/runtime/core/exec_aten/exec_aten.h>
@@ -65,6 +66,18 @@ class OpFullLikeTest : public OperatorTest {
6566
ET_EXPECT_KERNEL_FAILURE(
6667
context_, op_full_like_out(in, value, memory_format, out));
6768
}
69+
70+
template <ScalarType DTYPE>
71+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
72+
TensorFactory<DTYPE> tf;
73+
const std::vector<int32_t> sizes = {2, 2};
74+
Tensor in = tf.zeros(sizes);
75+
Tensor out = tf.zeros(sizes);
76+
optional<MemoryFormat> memory_format;
77+
78+
ET_EXPECT_KERNEL_FAILURE(
79+
context_, op_full_like_out(in, bad_value, memory_format, out));
80+
}
6881
};
6982

7083
template <>
@@ -209,3 +222,5 @@ TEST_F(OpFullLikeTest, HalfSupport) {
209222
op_full_like_out(in, INFINITY, memory_format, out);
210223
EXPECT_TENSOR_CLOSE(out, tf.full({2, 3}, INFINITY));
211224
}
225+
226+
GENERATE_SCALAR_OVERFLOW_TESTS(OpFullLikeTest)

0 commit comments

Comments
 (0)