Skip to content

Commit 49e6d98

Browse files
committed
embedding
1 parent 13fdca7 commit 49e6d98

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

kernels/portable/cpu/op_embedding.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,13 @@ void embedding_kernel(
3535
int64_t nbytes_per_entry = weight.size(1) * weight.element_size();
3636
const char* w_data = weight.const_data_ptr<char>();
3737
char* out_data = out.mutable_data_ptr<char>();
38-
const CTYPE* indices_ptr = indices.const_data_ptr<CTYPE>();
38+
CTYPE* indices_ptr = indices.mutable_data_ptr<CTYPE>();
3939
ssize_t weight_height = weight.size(0);
4040
const auto indices_numel = indices.numel();
4141
for (int i = 0; i < indices_numel; i++) {
4242
// Ensure index is larger than 0 and smaller than weight.size(0)
43+
indices_ptr[i] = indices_ptr[i] < 0 ? 0 : indices_ptr[i];
44+
indices_ptr[i] = indices_ptr[i] >= weight_height ? weight_height-1 : indices_ptr[i];
4345
ET_KERNEL_CHECK_MSG(
4446
ctx,
4547
indices_ptr[i] < weight_height,
@@ -118,7 +120,7 @@ Tensor& embedding_out(
118120

119121
ET_SWITCH_TWO_TYPES(
120122
Long, Int, ix_type, ctx, "op_embedding.out", CTYPE, [&]() {
121-
embedding_kernel<CTYPE>(ctx, weight, indices, out);
123+
embedding_kernel<long>(ctx, weight, indices, out);
122124
});
123125

124126
return out;

0 commit comments

Comments
 (0)