@@ -37,28 +37,27 @@ Tensor& eq_tensor_out(
3737 ET_KERNEL_CHECK (
3838 ctx, tensors_have_same_dim_order (a, b, out), InvalidArgument, out);
3939
40- ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " eq.Scalar_out" , CTYPE_A, [&]() {
41- ET_SWITCH_REAL_TYPES_AND (
42- Bool, b_type, ctx, " eq.Scalar_out" , CTYPE_B, [&]() {
43- using CTYPE_IN =
44- typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
45- ET_DCHECK (
46- CppTypeToScalarType<CTYPE_IN>::value ==
47- promoteTypes (a_type, b_type));
48- ET_SWITCH_REAL_TYPES_AND (
49- Bool, out_type, ctx, " eq.Scalar_out" , CTYPE_OUT, [&]() {
50- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
51- [](const CTYPE_A val_a, const CTYPE_B val_b) {
52- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
53- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
54- bool value = a_casted == b_casted;
55- return static_cast <CTYPE_OUT>(value);
56- },
57- a,
58- b,
59- out);
60- });
61- });
40+ constexpr auto name = " eq.Tensor_out" ;
41+
42+ ET_SWITCH_REALHBBF16_TYPES (a_type, ctx, name, CTYPE_A, [&]() {
43+ ET_SWITCH_REALHBBF16_TYPES (b_type, ctx, name, CTYPE_B, [&]() {
44+ using CTYPE_IN =
45+ typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
46+ ET_DCHECK (
47+ CppTypeToScalarType<CTYPE_IN>::value == promoteTypes (a_type, b_type));
48+ ET_SWITCH_REALHBBF16_TYPES (out_type, ctx, name, CTYPE_OUT, [&]() {
49+ apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
50+ [](const CTYPE_A val_a, const CTYPE_B val_b) {
51+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
52+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
53+ bool value = a_casted == b_casted;
54+ return static_cast <CTYPE_OUT>(value);
55+ },
56+ a,
57+ b,
58+ out);
59+ });
60+ });
6261 });
6362
6463 return out;
@@ -86,27 +85,28 @@ Tensor& eq_scalar_out(
8685 ET_KERNEL_CHECK (
8786 ctx, tensors_have_same_dim_order (a, out), InvalidArgument, out);
8887
89- ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " eq.Scalar_out" , CTYPE_A, [&]() {
90- ET_SWITCH_SCALAR_OBJ_TYPES (b_type, ctx, " eq.Scalar_out" , CTYPE_B, [&]() {
88+ constexpr auto name = " eq.Scalar_out" ;
89+
90+ ET_SWITCH_REALHBBF16_TYPES (a_type, ctx, name, CTYPE_A, [&]() {
91+ ET_SWITCH_SCALAR_OBJ_TYPES (b_type, ctx, name, CTYPE_B, [&]() {
9192 using CTYPE_IN =
9293 typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
9394 ET_DCHECK (
9495 CppTypeToScalarType<CTYPE_IN>::value == promoteTypes (a_type, b_type));
95- ET_SWITCH_REAL_TYPES_AND (
96- Bool, out_type, ctx, " eq.Scalar_out" , CTYPE_OUT, [&]() {
97- CTYPE_B val_b = 0 ;
98- utils::extract_scalar (b, &val_b);
99- apply_unary_map_fn (
100- [val_b](const CTYPE_A val_a) {
101- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
102- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
103- bool value = a_casted == b_casted;
104- return static_cast <CTYPE_OUT>(value);
105- },
106- a.const_data_ptr <CTYPE_A>(),
107- out.mutable_data_ptr <CTYPE_OUT>(),
108- out.numel ());
109- });
96+ ET_SWITCH_REALHBBF16_TYPES (out_type, ctx, name, CTYPE_OUT, [&]() {
97+ CTYPE_B val_b = 0 ;
98+ utils::extract_scalar (b, &val_b);
99+ apply_unary_map_fn (
100+ [val_b](const CTYPE_A val_a) {
101+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
102+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
103+ bool value = a_casted == b_casted;
104+ return static_cast <CTYPE_OUT>(value);
105+ },
106+ a.const_data_ptr <CTYPE_A>(),
107+ out.mutable_data_ptr <CTYPE_OUT>(),
108+ out.numel ());
109+ });
110110 });
111111 });
112112
0 commit comments