File tree Expand file tree Collapse file tree 2 files changed +8
-2
lines changed
python/paddle/v2/framework/tests Expand file tree Collapse file tree 2 files changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -32,6 +32,9 @@ class LookupTableOp : public framework::OperatorWithKernel {
32
32
auto table_dims = ctx->GetInputDim (" W" );
33
33
auto ids_dims = ctx->GetInputDim (" Ids" );
34
34
35
+ PADDLE_ENFORCE_EQ (ids_dims.size (), 2 );
36
+ PADDLE_ENFORCE_EQ (ids_dims[1 ], 1 );
37
+
35
38
ctx->SetOutputDim (" Out" , {ids_dims[0 ], table_dims[1 ]});
36
39
ctx->ShareLoD (" Ids" , /* ->*/ " Out" );
37
40
}
@@ -53,7 +56,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
53
56
" which is a learnable parameter." );
54
57
AddInput (" Ids" ,
55
58
" An input with type int32 or int64"
56
- " contains the ids to be looked up in W." );
59
+ " contains the ids to be looked up in W."
60
+ " Ids must be a column vector with rank = 2."
61
+ " The 2nd dimension size must be 1" );
57
62
AddOutput (" Out" , " The lookup results, which have the same type with W." );
58
63
AddComment (R"DOC(
59
64
This operator is used to perform lookups on the parameter W,
Original file line number Diff line number Diff line change @@ -8,7 +8,8 @@ def setUp(self):
8
8
self .op_type = "lookup_table"
9
9
table = np .random .random ((17 , 31 )).astype ("float32" )
10
10
ids = np .random .randint (0 , 17 , 4 ).astype ("int32" )
11
- self .inputs = {'W' : table , 'Ids' : ids }
11
+ ids_expand = np .expand_dims (ids , axis = 1 )
12
+ self .inputs = {'W' : table , 'Ids' : ids_expand }
12
13
self .outputs = {'Out' : table [ids ]}
13
14
14
15
def test_check_output (self ):
You can’t perform that action at this time.
0 commit comments