Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kernels/portable/cpu/op_topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ std::tuple<Tensor&, Tensor&> topk_values(

bool temp_mem_allocated = false;

ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
using elem_t = std::pair<CTYPE, int64_t>;
size_t temp_mem_size = nonempty_size(in, dim) * sizeof(elem_t);

Expand Down
1 change: 1 addition & 0 deletions kernels/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ set(all_test_sources
"op_tan_test.cpp"
"op_tanh_test.cpp"
"op_to_copy_test.cpp"
"op_topk_test.cpp"
"op_transpose_copy_test.cpp"
"op_tril_test.cpp"
"op_trunc_test.cpp"
Expand Down
53 changes: 30 additions & 23 deletions kernels/test/op_topk_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,32 +118,39 @@ class OpTopkValuesTest : public ::testing::Test {
// first.
torch::executor::runtime_init();
}

template <ScalarType DTYPE>
void run_smoke_test() {
TensorFactory<DTYPE> tfDtype;
TensorFactory<ScalarType::Long> tfLong;

Tensor input =
tfDtype.make({3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
int64_t k = 2;
int64_t dim = 0;
bool largest = true;
bool sorted = true;
Tensor values = tfDtype.zeros({2, 2, 2});
Tensor indices = tfLong.zeros({2, 2, 2});
Tensor values_expected = tfDtype.make({2, 2, 2}, {9, 10, 11, 12, 5, 6, 7, 8});
Tensor indices_expected = tfLong.make({2, 2, 2}, {2, 2, 2, 2, 1, 1, 1, 1});
op_topk_values(input, k, dim, largest, sorted, values, indices);
EXPECT_TENSOR_CLOSE(values, values_expected);
EXPECT_TENSOR_EQ(indices, indices_expected);

largest = false;
values_expected = tfDtype.make({2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
indices_expected = tfLong.make({2, 2, 2}, {0, 0, 0, 0, 1, 1, 1, 1});
op_topk_values(input, k, dim, largest, sorted, values, indices);
EXPECT_TENSOR_CLOSE(values, values_expected);
EXPECT_TENSOR_EQ(indices, indices_expected);

}
};

TEST_F(OpTopkValuesTest, SmokeTest) {
TensorFactory<ScalarType::Float> tfFloat;
TensorFactory<ScalarType::Long> tfLong;

Tensor input =
tfFloat.make({3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
int64_t k = 2;
int64_t dim = 0;
bool largest = true;
bool sorted = true;
Tensor values = tfFloat.zeros({2, 2, 2});
Tensor indices = tfLong.zeros({2, 2, 2});
Tensor values_expected = tfFloat.make({2, 2, 2}, {9, 10, 11, 12, 5, 6, 7, 8});
Tensor indices_expected = tfLong.make({2, 2, 2}, {2, 2, 2, 2, 1, 1, 1, 1});
op_topk_values(input, k, dim, largest, sorted, values, indices);
EXPECT_TENSOR_CLOSE(values, values_expected);
EXPECT_TENSOR_EQ(indices, indices_expected);

largest = false;
values_expected = tfFloat.make({2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
indices_expected = tfLong.make({2, 2, 2}, {0, 0, 0, 0, 1, 1, 1, 1});
op_topk_values(input, k, dim, largest, sorted, values, indices);
EXPECT_TENSOR_CLOSE(values, values_expected);
EXPECT_TENSOR_EQ(indices, indices_expected);
#define RUN_SMOKE_TEST(ctype, dtype) run_smoke_test<ScalarType::dtype>();
ET_FORALL_REALHBF16_TYPES(RUN_SMOKE_TEST);
}

TEST_F(OpTopkValuesTest, NonPartialSort) {
Expand Down
Loading