@@ -26,10 +26,9 @@ using Tensor = executorch::aten::Tensor;
26
26
27
27
namespace {
28
28
29
- template <typename CTYPE_VAL, typename CTYPE_OUT, typename CTYPE_CAST>
29
+ template <typename CTYPE_OUT, typename CTYPE_CAST>
30
30
/* * Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */
31
- bool is_out_of_bounds (CTYPE_VAL val) {
32
- const CTYPE_CAST val_cast = static_cast <CTYPE_CAST>(val);
31
+ bool is_out_of_bounds (CTYPE_CAST val_cast) {
33
32
return val_cast < std::numeric_limits<CTYPE_OUT>::lowest () ||
34
33
val_cast > std::numeric_limits<CTYPE_OUT>::max ();
35
34
}
@@ -41,26 +40,24 @@ ET_NODISCARD bool check_bounds(
41
40
const char * val_name) {
42
41
auto is_valid = true ;
43
42
44
- ET_SWITCH_SCALAR_OBJ_TYPES (val_type, ctx, " clamp.out" , CTYPE_VAL, [&]() {
45
- CTYPE_VAL val = 0 ;
46
- utils::extract_scalar (val_scalar, &val);
47
- if (isIntegralType (out_type, /* includeBool=*/ false )) {
48
- ET_SWITCH_INT_TYPES (out_type, ctx, " clamp.out" , CTYPE_OUT, [&]() {
49
- if (is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, long >(val)) {
50
- ET_LOG (Error, " %s value out of bounds" , val_name);
51
- is_valid = false ;
52
- }
53
- });
54
- } else if (isFloatingType (out_type)) {
55
- ET_SWITCH_FLOATH_TYPES (out_type, ctx, " clamp" , CTYPE_OUT, [&]() {
56
- if (std::isfinite (val) &&
57
- is_out_of_bounds<CTYPE_VAL, CTYPE_OUT, double >(val)) {
58
- ET_LOG (Error, " %s value out of bounds" , val_name);
59
- is_valid = false ;
60
- }
61
- });
62
- }
63
- });
43
+ if (isIntegralType (out_type, /* includeBool=*/ false )) {
44
+ const long val_long = utils::scalar_to<long >(val_scalar);
45
+ ET_SWITCH_INT_TYPES (out_type, ctx, " clamp.out" , CTYPE_OUT, [&]() {
46
+ if (is_out_of_bounds<CTYPE_OUT, long >(val_long)) {
47
+ ET_LOG (Error, " %s value out of bounds" , val_name);
48
+ is_valid = false ;
49
+ }
50
+ });
51
+ } else if (isFloatingType (out_type)) {
52
+ ET_SWITCH_FLOATHBF16_TYPES (out_type, ctx, " clamp.out" , CTYPE_OUT, [&]() {
53
+ const double val_double = utils::scalar_to<double >(val_scalar);
54
+ if (std::isfinite (val_double) &&
55
+ is_out_of_bounds<CTYPE_OUT, double >(val_double)) {
56
+ ET_LOG (Error, " %s value out of bounds" , val_name);
57
+ is_valid = false ;
58
+ }
59
+ });
60
+ }
64
61
65
62
return is_valid;
66
63
}
0 commit comments