From 00787097dbf912ec854c8ab15b8e2ca9680d7f89 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 22 Jan 2025 14:24:05 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- kernels/portable/cpu/op_nonzero.cpp | 7 +++---- kernels/test/op_nonzero_test.cpp | 8 ++++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/kernels/portable/cpu/op_nonzero.cpp b/kernels/portable/cpu/op_nonzero.cpp index 6c149ec4de5..77f80126d9f 100644 --- a/kernels/portable/cpu/op_nonzero.cpp +++ b/kernels/portable/cpu/op_nonzero.cpp @@ -88,10 +88,9 @@ Tensor& nonzero_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { ET_KERNEL_CHECK(ctx, check_nonzero_args(in, out), InvalidArgument, out); - ET_SWITCH_REAL_TYPES_AND( - Bool, in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] { - nonzero(ctx, in, out); - }); + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] { + nonzero(ctx, in, out); + }); return out; } diff --git a/kernels/test/op_nonzero_test.cpp b/kernels/test/op_nonzero_test.cpp index 2eb828413ef..a0b32942050 100644 --- a/kernels/test/op_nonzero_test.cpp +++ b/kernels/test/op_nonzero_test.cpp @@ -28,9 +28,9 @@ class OpNonzeroTest : public OperatorTest { void test_dtype() { TensorFactory tf_input; TensorFactory tf_long; - // clang-format off - Tensor a = tf_input.make(/*sizes=*/{2, 2}, /*data=*/{2, 0, - 2, 4}); + // clang-format offs + Tensor a = tf_input.make( + /*sizes=*/{2, 2}, /*data=*/{CTYPE(2), CTYPE(0), CTYPE(2), CTYPE(4)}); // clang-format on Tensor out = tf_long.zeros({3, 2}); @@ -45,7 +45,7 @@ class OpNonzeroTest : public OperatorTest { TEST_F(OpNonzeroTest, AllDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } From 719038eff28feef859040f10def30e4bab54547c Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 22 Jan 2025 15:48:47 -0800 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- kernels/test/op_nonzero_test.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/kernels/test/op_nonzero_test.cpp b/kernels/test/op_nonzero_test.cpp index a0b32942050..c1948a439b9 100644 --- a/kernels/test/op_nonzero_test.cpp +++ b/kernels/test/op_nonzero_test.cpp @@ -28,10 +28,8 @@ class OpNonzeroTest : public OperatorTest { void test_dtype() { TensorFactory tf_input; TensorFactory tf_long; - // clang-format offs Tensor a = tf_input.make( /*sizes=*/{2, 2}, /*data=*/{CTYPE(2), CTYPE(0), CTYPE(2), CTYPE(4)}); - // clang-format on Tensor out = tf_long.zeros({3, 2}); op_nonzero_out(a, out);