@@ -26,10 +26,9 @@ using Tensor = executorch::aten::Tensor;
2626
2727namespace {
2828
29- template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
29+ template <typename CTYPE_OUT, typename CTYPE_CAST>
3030/* * Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */
31- bool is_out_of_bounds (CTYPE_VAL val) {
32- const CTYPE_CAST val_cast = static_cast <CTYPE_CAST>(val);
31+ bool is_out_of_bounds (CTYPE_CAST val_cast) {
3332 return val_cast < std::numeric_limits<CTYPE_OUT>::lowest () ||
3433 val_cast > std::numeric_limits<CTYPE_OUT>::max ();
3534}
@@ -41,26 +40,24 @@ ET_NODISCARD bool check_bounds(
4140 const char * val_name) {
4241 auto is_valid = true ;
4342
44- ET_SWITCH_SCALAR_OBJ_TYPES (val_type, ctx, " clamp.out" , CTYPE_VAL, [&]() {
45- CTYPE_VAL val = 0 ;
46- utils::extract_scalar (val_scalar, &val);
47- if (isIntegralType (out_type, /* includeBool=*/ false )) {
48- ET_SWITCH_INT_TYPES (out_type, ctx, " clamp.out" , CTYPE_OUT, [&]() {
49- if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long >(val)) {
50- ET_LOG (Error, " %s value out of bounds" , val_name);
51- is_valid = false ;
52- }
53- });
54- } else if (isFloatingType (out_type)) {
55- ET_SWITCH_FLOATH_TYPES (out_type, ctx, " clamp" , CTYPE_OUT, [&]() {
56- if (std::isfinite (val) &&
57- is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double >(val)) {
58- ET_LOG (Error, " %s value out of bounds" , val_name);
59- is_valid = false ;
60- }
61- });
62- }
63- });
43+ if (isIntegralType (out_type, /* includeBool=*/ false )) {
44+ const long val_long = utils::scalar_to<long >(val_scalar);
45+ ET_SWITCH_INT_TYPES (out_type, ctx, " clamp.out" , CTYPE_OUT, [&]() {
46+ if (is_out_of_bounds<CTYPE_OUT, long >(val_long)) {
47+ ET_LOG (Error, " %s value out of bounds" , val_name);
48+ is_valid = false ;
49+ }
50+ });
51+ } else if (isFloatingType (out_type)) {
52+ ET_SWITCH_FLOATHBF16_TYPES (out_type, ctx, " clamp.out" , CTYPE_OUT, [&]() {
53+ const double val_double = utils::scalar_to<double >(val_scalar);
54+ if (std::isfinite (val_double) &&
55+ is_out_of_bounds<CTYPE_OUT, double >(val_double)) {
56+ ET_LOG (Error, " %s value out of bounds" , val_name);
57+ is_valid = false ;
58+ }
59+ });
60+ }
6461
6562 return is_valid;
6663}
0 commit comments