From c803e04ecd85effdea96a222324337a118001d64 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 17 Jan 2025 16:09:39 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- kernels/portable/cpu/op_split_with_sizes_copy.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kernels/portable/cpu/op_split_with_sizes_copy.cpp b/kernels/portable/cpu/op_split_with_sizes_copy.cpp index ab7dce1d1af..31233adfbf3 100644 --- a/kernels/portable/cpu/op_split_with_sizes_copy.cpp +++ b/kernels/portable/cpu/op_split_with_sizes_copy.cpp @@ -71,8 +71,8 @@ void split_with_sizes_copy_out( ScalarType in_type = in.scalar_type(); ScalarType out_type = out[0].scalar_type(); - ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE_IN, [&]() { - ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() { + ET_SWITCH_REALHBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&]() { + ET_SWITCH_REALHBF16_TYPES(out_type, ctx, __func__, CTYPE_OUT, [&]() { const CTYPE_IN* in_data = in.const_data_ptr(); // Iterate through list of out tensors From d8ff739e4cc0472b3054a8b086f79b2b85ad5758 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 22 Jan 2025 09:42:35 -0800 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- kernels/portable/cpu/op_split_with_sizes_copy.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kernels/portable/cpu/op_split_with_sizes_copy.cpp b/kernels/portable/cpu/op_split_with_sizes_copy.cpp index 31233adfbf3..f6bfffdbf04 100644 --- a/kernels/portable/cpu/op_split_with_sizes_copy.cpp +++ b/kernels/portable/cpu/op_split_with_sizes_copy.cpp @@ -71,8 +71,8 @@ void split_with_sizes_copy_out( ScalarType in_type = in.scalar_type(); ScalarType out_type = out[0].scalar_type(); - ET_SWITCH_REALHBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&]() { - ET_SWITCH_REALHBF16_TYPES(out_type, ctx, __func__, CTYPE_OUT, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&]() { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, __func__, CTYPE_OUT, [&]() { const CTYPE_IN* in_data = in.const_data_ptr(); // Iterate through list of out tensors