Skip to content

Commit 40e7caf

Browse files
QiJunereyoung
authored andcommitted
ensure ids in lookup table op must be a column vector (#4987)
* ensure ids in lookup table op must be a column vector * follow comments
1 parent 7d653c4 commit 40e7caf

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

paddle/operators/lookup_table_op.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ class LookupTableOp : public framework::OperatorWithKernel {
3232
auto table_dims = ctx->GetInputDim("W");
3333
auto ids_dims = ctx->GetInputDim("Ids");
3434

35+
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
36+
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
37+
3538
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
3639
ctx->ShareLoD("Ids", /*->*/ "Out");
3740
}
@@ -53,7 +56,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
5356
" which is a learnable parameter.");
5457
AddInput("Ids",
5558
"An input with type int32 or int64"
56-
"contains the ids to be looked up in W.");
59+
"contains the ids to be looked up in W."
60+
"Ids must be a column vector with rank = 2."
61+
"The 2nd dimension size must be 1");
5762
AddOutput("Out", "The lookup results, which have the same type with W.");
5863
AddComment(R"DOC(
5964
This operator is used to perform lookups on the parameter W,

python/paddle/v2/framework/tests/test_lookup_table_op.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ def setUp(self):
88
self.op_type = "lookup_table"
99
table = np.random.random((17, 31)).astype("float32")
1010
ids = np.random.randint(0, 17, 4).astype("int32")
11-
self.inputs = {'W': table, 'Ids': ids}
11+
ids_expand = np.expand_dims(ids, axis=1)
12+
self.inputs = {'W': table, 'Ids': ids_expand}
1213
self.outputs = {'Out': table[ids]}
1314

1415
def test_check_output(self):

0 commit comments

Comments
 (0)