Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions kernels/portable/cpu/op_nonzero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,9 @@ Tensor& nonzero_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {

ET_KERNEL_CHECK(ctx, check_nonzero_args(in, out), InvalidArgument, out);

ET_SWITCH_REAL_TYPES_AND(
Bool, in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] {
nonzero<CTYPE>(ctx, in, out);
});
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] {
nonzero<CTYPE>(ctx, in, out);
});

return out;
}
Expand Down
8 changes: 4 additions & 4 deletions kernels/test/op_nonzero_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ class OpNonzeroTest : public OperatorTest {
void test_dtype() {
TensorFactory<DTYPE> tf_input;
TensorFactory<ScalarType::Long> tf_long;
// clang-format off
Tensor a = tf_input.make(/*sizes=*/{2, 2}, /*data=*/{2, 0,
2, 4});
// clang-format offs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

offs --> off?

Tensor a = tf_input.make(
/*sizes=*/{2, 2}, /*data=*/{CTYPE(2), CTYPE(0), CTYPE(2), CTYPE(4)});
// clang-format on
Tensor out = tf_long.zeros({3, 2});

Expand All @@ -45,7 +45,7 @@ class OpNonzeroTest : public OperatorTest {

TEST_F(OpNonzeroTest, AllDtypesSupported) {
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
ET_FORALL_REAL_TYPES(TEST_ENTRY);
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

Expand Down
Loading