diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index af082a18e78..c2b9c73f2ea 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -26,10 +26,9 @@ using Tensor = executorch::aten::Tensor; namespace { -template +template /** Check if val, when cast to CTYPE_CAST, is not in the range of CTYPE_OUT */ -bool is_out_of_bounds(CTYPE_VAL val) { - const CTYPE_CAST val_cast = static_cast(val); +bool is_out_of_bounds(CTYPE_CAST val_cast) { return val_cast < std::numeric_limits::lowest() || val_cast > std::numeric_limits::max(); } @@ -41,26 +40,24 @@ ET_NODISCARD bool check_bounds( const char* val_name) { auto is_valid = true; - ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, "clamp.out", CTYPE_VAL, [&]() { - CTYPE_VAL val = 0; - utils::extract_scalar(val_scalar, &val); - if (isIntegralType(out_type, /*includeBool=*/false)) { - ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() { - if (is_out_of_bounds(val)) { - ET_LOG(Error, "%s value out of bounds", val_name); - is_valid = false; - } - }); - } else if (isFloatingType(out_type)) { - ET_SWITCH_FLOATH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() { - if (std::isfinite(val) && - is_out_of_bounds(val)) { - ET_LOG(Error, "%s value out of bounds", val_name); - is_valid = false; - } - }); - } - }); + if (isIntegralType(out_type, /*includeBool=*/false)) { + const long val_long = utils::scalar_to(val_scalar); + ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() { + if (is_out_of_bounds(val_long)) { + ET_LOG(Error, "%s value out of bounds", val_name); + is_valid = false; + } + }); + } else if (isFloatingType(out_type)) { + ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() { + const double val_double = utils::scalar_to(val_scalar); + if (std::isfinite(val_double) && + is_out_of_bounds(val_double)) { + ET_LOG(Error, "%s value out of bounds", val_name); + is_valid = false; + } + }); + } return is_valid; } diff --git a/kernels/portable/cpu/op_constant_pad_nd.cpp b/kernels/portable/cpu/op_constant_pad_nd.cpp index e1e37e5b4d2..6e643e1b945 100644 --- a/kernels/portable/cpu/op_constant_pad_nd.cpp +++ b/kernels/portable/cpu/op_constant_pad_nd.cpp @@ -183,17 +183,10 @@ Tensor& constant_pad_nd_out( "Failed to resize output tensor."); ScalarType in_type = in.scalar_type(); - ScalarType value_type = utils::get_scalar_dtype(value); ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() { - CTYPE value_v; - ET_SWITCH_SCALAR_OBJ_TYPES( - value_type, ctx, "constant_pad_nd.out", CTYPE_VALUE, [&]() { - CTYPE_VALUE val = 0; - utils::extract_scalar(value, &val); - value_v = static_cast(val); - }); - constant_pad_nd_out_impl(in, pad, value_v, out); + const CTYPE value_casted = utils::scalar_to(value); + constant_pad_nd_out_impl(in, pad, value_casted, out); }); return out; diff --git a/kernels/portable/cpu/op_fill.cpp b/kernels/portable/cpu/op_fill.cpp index 3ed8557c29e..b985e2f4f07 100644 --- a/kernels/portable/cpu/op_fill.cpp +++ b/kernels/portable/cpu/op_fill.cpp @@ -26,7 +26,6 @@ Tensor& fill_scalar_out( (void)ctx; ScalarType a_type = a.scalar_type(); - ScalarType b_type = utils::get_scalar_dtype(b); ScalarType out_type = out.scalar_type(); ET_KERNEL_CHECK(ctx, a_type == out_type, InvalidArgument, out); @@ -43,12 +42,7 @@ Tensor& fill_scalar_out( "Failed to resize output tensor."); ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] { - CTYPE_A b_casted; - ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "fill.Scalar_out", CTYPE_B, [&] { - CTYPE_B b_val = 0; - utils::extract_scalar(b, &b_val); - b_casted = static_cast(b_val); - }); + const CTYPE_A b_casted = utils::scalar_to(b); apply_unary_map_fn( [b_casted](const CTYPE_A val_a) { return b_casted; }, diff --git a/kernels/portable/cpu/op_full.cpp b/kernels/portable/cpu/op_full.cpp index 69b4c8fd150..83ffcad45a6 100644 --- a/kernels/portable/cpu/op_full.cpp +++ b/kernels/portable/cpu/op_full.cpp @@ -24,7 +24,6 @@ Tensor& full_out( Tensor& out) { (void)ctx; - ScalarType val_type = utils::get_scalar_dtype(fill_value); ScalarType out_type = out.scalar_type(); // Resize for dynamic shape @@ -37,18 +36,12 @@ Tensor& full_out( constexpr auto name = "full.out"; - ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] { - CTYPE_VAL val; - ET_KERNEL_CHECK( - ctx, utils::extract_scalar(fill_value, &val), InvalidArgument, ); - - ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] { - CTYPE_OUT val_casted = static_cast(val); - auto data_out = out.mutable_data_ptr(); - for (const auto i : c10::irange(out.numel())) { - data_out[i] = val_casted; - } - }); + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] { + CTYPE_OUT val_casted = utils::scalar_to(fill_value); + auto data_out = out.mutable_data_ptr(); + for (const auto i : c10::irange(out.numel())) { + data_out[i] = val_casted; + } }); return out; diff --git a/kernels/portable/cpu/op_hardtanh.cpp b/kernels/portable/cpu/op_hardtanh.cpp index 97355b0f17c..8ec73b07856 100644 --- a/kernels/portable/cpu/op_hardtanh.cpp +++ b/kernels/portable/cpu/op_hardtanh.cpp @@ -40,26 +40,13 @@ Tensor& hardtanh_out( ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); ScalarType in_type = in.scalar_type(); - ScalarType min_type = utils::get_scalar_dtype(min); - ScalarType max_type = utils::get_scalar_dtype(max); ScalarType out_type = out.scalar_type(); ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out); ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() { - CTYPE min_casted; - ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "hardtanh.out", CTYPE_MIN, [&]() { - CTYPE_MIN min_val = 0; - utils::extract_scalar(min, &min_val); - min_casted = static_cast(min_val); - }); - - CTYPE max_casted; - ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "hardtanh.out", CTYPE_MAX, [&]() { - CTYPE_MAX max_val = 0; - utils::extract_scalar(max, &max_val); - max_casted = static_cast(max_val); - }); + const CTYPE min_casted = utils::scalar_to(min); + const CTYPE max_casted = utils::scalar_to(max); apply_unary_map_fn( [min_casted, max_casted](const CTYPE val_in) { diff --git a/kernels/portable/cpu/op_leaky_relu.cpp b/kernels/portable/cpu/op_leaky_relu.cpp index a04365814c7..11860c8d129 100644 --- a/kernels/portable/cpu/op_leaky_relu.cpp +++ b/kernels/portable/cpu/op_leaky_relu.cpp @@ -39,19 +39,12 @@ Tensor& leaky_relu_out( ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); ScalarType in_type = in.scalar_type(); - ScalarType sc_type = utils::get_scalar_dtype(negative_slope); ScalarType out_type = out.scalar_type(); ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out); ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() { - CTYPE negative_slope_casted = 0; - ET_SWITCH_SCALAR_OBJ_TYPES( - sc_type, ctx, "leaky_relu.out", CTYPE_MIN, [&]() { - CTYPE_MIN negative_slope_val = 0; - utils::extract_scalar(negative_slope, &negative_slope_val); - negative_slope_casted = static_cast(negative_slope_val); - }); + const CTYPE negative_slope_casted = utils::scalar_to(negative_slope); apply_unary_map_fn( [negative_slope_casted](const CTYPE val_in) { diff --git a/kernels/portable/cpu/op_scalar_tensor.cpp b/kernels/portable/cpu/op_scalar_tensor.cpp index 5be65a2e060..e111a9ac869 100644 --- a/kernels/portable/cpu/op_scalar_tensor.cpp +++ b/kernels/portable/cpu/op_scalar_tensor.cpp @@ -20,19 +20,21 @@ scalar_tensor_out(KernelRuntimeContext& ctx, const Scalar& s, Tensor& out) { ET_KERNEL_CHECK( ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out); - ScalarType s_type = utils::get_scalar_dtype(s); ScalarType out_type = out.scalar_type(); constexpr auto name = "scalar_tensor.out"; - ET_SWITCH_REAL_TYPES_AND3( - Half, Bool, BFloat16, out_type, ctx, name, CTYPE, [&]() { - ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, name, CTYPE_S, [&]() { - CTYPE_S val_s = 0; - utils::extract_scalar(s, &val_s); - out.mutable_data_ptr()[0] = convert(val_s); - }); - }); + if (s.isFloatingPoint() && + executorch::runtime::isIntegralType(out_type, false)) { + ET_SWITCH_INT_TYPES(out_type, ctx, name, CTYPE, [&]() { + out.mutable_data_ptr()[0] = + static_cast(utils::scalar_to(s)); + }); + } else { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE, [&]() { + out.mutable_data_ptr()[0] = utils::scalar_to(s); + }); + } return out; } diff --git a/kernels/portable/cpu/op_scatter.cpp b/kernels/portable/cpu/op_scatter.cpp index f8f4b21264e..7de0ec4d5f9 100644 --- a/kernels/portable/cpu/op_scatter.cpp +++ b/kernels/portable/cpu/op_scatter.cpp @@ -151,17 +151,11 @@ Tensor& scatter_value_out( ET_KERNEL_CHECK( ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); - ScalarType val_type = utils::get_scalar_dtype(value); - constexpr auto name = "scatter.value_out"; - ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] { - CTYPE_VAL val; - ET_KERNEL_CHECK(ctx, utils::extract_scalar(value, &val), InvalidArgument, ); - - ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { - scatter_value_helper(in, dim, index, val, out); - }); + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { + const CTYPE val = utils::scalar_to(value); + scatter_value_helper(in, dim, index, val, out); }); return out; diff --git a/kernels/portable/cpu/op_var.cpp b/kernels/portable/cpu/op_var.cpp index f09f1d92bc9..a95b3a9f167 100644 --- a/kernels/portable/cpu/op_var.cpp +++ b/kernels/portable/cpu/op_var.cpp @@ -127,12 +127,7 @@ Tensor& var_correction_out( double correction_val = 1; if (correction.has_value()) { - ScalarType corr_type = utils::get_scalar_dtype(correction.value()); - ET_SWITCH_SCALAR_OBJ_TYPES(corr_type, ctx, name, CTYPE_CORR, [&]() { - CTYPE_CORR corr_val = 0; - utils::extract_scalar(correction.value(), &corr_val); - correction_val = static_cast(corr_val); - }); + correction_val = utils::scalar_to(correction.value()); } const size_t num = get_reduced_dim_product(in, dim_list);