Skip to content

Commit 0078709

Browse files
committed
Update
[ghstack-poisoned]
1 parent 6fe6870 commit 0078709

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

kernels/portable/cpu/op_nonzero.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,9 @@ Tensor& nonzero_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
8888

8989
ET_KERNEL_CHECK(ctx, check_nonzero_args(in, out), InvalidArgument, out);
9090

91-
ET_SWITCH_REAL_TYPES_AND(
92-
Bool, in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] {
93-
nonzero<CTYPE>(ctx, in, out);
94-
});
91+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] {
92+
nonzero<CTYPE>(ctx, in, out);
93+
});
9594

9695
return out;
9796
}

kernels/test/op_nonzero_test.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ class OpNonzeroTest : public OperatorTest {
2828
void test_dtype() {
2929
TensorFactory<DTYPE> tf_input;
3030
TensorFactory<ScalarType::Long> tf_long;
31-
// clang-format off
32-
Tensor a = tf_input.make(/*sizes=*/{2, 2}, /*data=*/{2, 0,
33-
2, 4});
31+
// clang-format offs
32+
Tensor a = tf_input.make(
33+
/*sizes=*/{2, 2}, /*data=*/{CTYPE(2), CTYPE(0), CTYPE(2), CTYPE(4)});
3434
// clang-format on
3535
Tensor out = tf_long.zeros({3, 2});
3636

@@ -45,7 +45,7 @@ class OpNonzeroTest : public OperatorTest {
4545

4646
TEST_F(OpNonzeroTest, AllDtypesSupported) {
4747
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
48-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
48+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
4949
#undef TEST_ENTRY
5050
}
5151

0 commit comments

Comments
 (0)