From 924968cc84fedb7ae109d463dc28219d39ce2992 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 23 Jan 2025 13:50:49 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- kernels/portable/cpu/op_upsample_bilinear2d.cpp | 2 +- kernels/test/CMakeLists.txt | 1 + kernels/test/op_upsample_bilinear2d_test.cpp | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/kernels/portable/cpu/op_upsample_bilinear2d.cpp b/kernels/portable/cpu/op_upsample_bilinear2d.cpp index ea2ff86b31f..931a1705885 100644 --- a/kernels/portable/cpu/op_upsample_bilinear2d.cpp +++ b/kernels/portable/cpu/op_upsample_bilinear2d.cpp @@ -123,7 +123,7 @@ Tensor& upsample_bilinear2d_vec_out( const auto kernel_scale_w = area_pixel_compute_scale( in.sizes()[3], out.sizes()[3], align_corners, scale_w); - ET_SWITCH_REAL_TYPES( + ET_SWITCH_REALHBF16_TYPES( in.scalar_type(), ctx, "upsample_bilinear2d.out", CTYPE, [&]() { upsample_bilinear2d_kernel_impl( in, align_corners, kernel_scale_h, kernel_scale_w, out); diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index 65ec529ecdf..67bced07771 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -229,6 +229,7 @@ set(all_test_sources "op_trunc_test.cpp" "op_unbind_copy_test.cpp" "op_unsqueeze_copy_test.cpp" + "op_upsample_bilinear2d_test.cpp" "op_var_test.cpp" "op_view_copy_test.cpp" "op_where_test.cpp" diff --git a/kernels/test/op_upsample_bilinear2d_test.cpp b/kernels/test/op_upsample_bilinear2d_test.cpp index c7b5332275b..4a97068560f 100644 --- a/kernels/test/op_upsample_bilinear2d_test.cpp +++ b/kernels/test/op_upsample_bilinear2d_test.cpp @@ -302,9 +302,9 @@ TEST_F(OpUpsampleBilinear2dTest, SmokeTestAlignCornersScales) { } TEST_F(OpUpsampleBilinear2dTest, DType) { -#define TEST_ENTRY(ctype, dtype) \ - test_upsample_bilinear2d_dtype(); \ - ET_FORALL_REAL_TYPES(TEST_ENTRY); +#define TEST_ENTRY(ctype, dtype) \ + test_upsample_bilinear2d_dtype(); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } From cdd702ad81bd2d66abaf0543e0ff74a24af5650a Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 23 Jan 2025 13:58:45 -0800 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- kernels/portable/cpu/op_upsample_nearest2d.cpp | 2 +- kernels/test/CMakeLists.txt | 1 + kernels/test/op_upsample_nearest2d_test.cpp | 7 +++---- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/kernels/portable/cpu/op_upsample_nearest2d.cpp b/kernels/portable/cpu/op_upsample_nearest2d.cpp index 93a88588d83..43ab32707dd 100644 --- a/kernels/portable/cpu/op_upsample_nearest2d.cpp +++ b/kernels/portable/cpu/op_upsample_nearest2d.cpp @@ -79,7 +79,7 @@ Tensor& upsample_nearest2d_vec_out( const auto kernel_scale_w = area_pixel_compute_scale( in.sizes()[3], out.sizes()[3], false, scale_w); - ET_SWITCH_REAL_TYPES( + ET_SWITCH_REALHBF16_TYPES( in.scalar_type(), ctx, "upsample_nearest2d.out", CTYPE, [&]() { upsample_nearest2d_kernel_impl( in, kernel_scale_h, kernel_scale_w, out); diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index 67bced07771..1bd63a2a5fe 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -230,6 +230,7 @@ set(all_test_sources "op_unbind_copy_test.cpp" "op_unsqueeze_copy_test.cpp" "op_upsample_bilinear2d_test.cpp" + "op_upsample_nearest2d_test.cpp" "op_var_test.cpp" "op_view_copy_test.cpp" "op_where_test.cpp" diff --git a/kernels/test/op_upsample_nearest2d_test.cpp b/kernels/test/op_upsample_nearest2d_test.cpp index 12301688002..93737b436ad 100644 --- a/kernels/test/op_upsample_nearest2d_test.cpp +++ b/kernels/test/op_upsample_nearest2d_test.cpp @@ -52,7 +52,6 @@ class OpUpsampleNearest2dTest : public OperatorTest { op_upsample_nearest2d_out( input, OptionalArrayRef({output_size.data(), output_size.size()}), - true, {}, out); @@ -254,9 +253,9 @@ TEST_F(OpUpsampleNearest2dTest, MultiBatchAndChannel) { } TEST_F(OpUpsampleNearest2dTest, DType) { -#define TEST_ENTRY(ctype, dtype) \ - test_upsample_nearest2d_dtype(); \ - ET_FORALL_REAL_TYPES(TEST_ENTRY); +#define TEST_ENTRY(ctype, dtype) \ + test_upsample_nearest2d_dtype(); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY }