@@ -73,22 +73,20 @@ Tensor& select_scatter_out(
7373 ScalarType in_type = in.scalar_type ();
7474 ScalarType src_type = src.scalar_type ();
7575
76- ET_SWITCH_REAL_TYPES_AND (
77- Bool, in_type, ctx, " select_scatter.out" , CTYPE, [&]() {
78- ET_SWITCH_REAL_TYPES_AND (
79- Bool, src_type, ctx, " select_scatter.out" , CTYPE_SRC, [&]() {
80- CTYPE* const out_data = out.mutable_data_ptr <CTYPE>();
81- const CTYPE_SRC* const src_data = src.const_data_ptr <CTYPE_SRC>();
82-
83- for (size_t i = 0 ; i < leading_dims; ++i) {
84- for (size_t j = 0 ; j < trailing_stride; ++j) {
85- out_data[start_offset + i * out_step + j] =
86- convert<CTYPE, CTYPE_SRC>(
87- src_data[i * trailing_stride + j]);
88- }
89- }
90- });
91- });
76+ ET_SWITCH_REALHBBF16_TYPES (in_type, ctx, " select_scatter.out" , CTYPE, [&]() {
77+ ET_SWITCH_REALHBBF16_TYPES (
78+ src_type, ctx, " select_scatter.out" , CTYPE_SRC, [&]() {
79+ CTYPE* const out_data = out.mutable_data_ptr <CTYPE>();
80+ const CTYPE_SRC* const src_data = src.const_data_ptr <CTYPE_SRC>();
81+
82+ for (size_t i = 0 ; i < leading_dims; ++i) {
83+ for (size_t j = 0 ; j < trailing_stride; ++j) {
84+ out_data[start_offset + i * out_step + j] =
85+ convert<CTYPE, CTYPE_SRC>(src_data[i * trailing_stride + j]);
86+ }
87+ }
88+ });
89+ });
9290
9391 return out;
9492}
0 commit comments