@@ -55,21 +55,20 @@ Tensor& stack_out(
5555 const size_t ninputs = tensors.size ();
5656
5757 const auto out_type = out.scalar_type ();
58- ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, " stack.out" , CTYPE_OUT, [&] {
58+ ET_SWITCH_REALHBBF16_TYPES ( out_type, ctx, " stack.out" , CTYPE_OUT, [&] {
5959 CTYPE_OUT* out_ptr = out.mutable_data_ptr <CTYPE_OUT>();
6060 for (size_t i = 0 ; i < outer; ++i) {
6161 for (size_t j = 0 ; j < ninputs; ++j) {
6262 const auto in_type = tensors[j].scalar_type ();
63- ET_SWITCH_REAL_TYPES_AND (
64- Bool, in_type, ctx, " stack.out" , CTYPE_IN, [&] {
65- const CTYPE_IN* const in_ptr =
66- tensors[j].const_data_ptr <CTYPE_IN>() + i * inner;
63+ ET_SWITCH_REALHBBF16_TYPES (in_type, ctx, " stack.out" , CTYPE_IN, [&] {
64+ const CTYPE_IN* const in_ptr =
65+ tensors[j].const_data_ptr <CTYPE_IN>() + i * inner;
6766
68- for (size_t k = 0 ; k < inner; ++k) {
69- out_ptr[k] = static_cast <CTYPE_OUT>(in_ptr[k]);
70- }
71- out_ptr += inner;
72- });
67+ for (size_t k = 0 ; k < inner; ++k) {
68+ out_ptr[k] = static_cast <CTYPE_OUT>(in_ptr[k]);
69+ }
70+ out_ptr += inner;
71+ });
7372 }
7473 }
7574 });
0 commit comments