@@ -22,32 +22,54 @@ limitations under the License. */
22
22
namespace paddle {
23
23
namespace operators {
24
24
25
+ using Tensor = framework::Tensor;
25
26
using LoDTensor = framework::LoDTensor;
26
27
using SelectedRows = framework::SelectedRows;
27
28
28
29
template <typename T>
29
30
class LookupTableKernel : public framework ::OpKernel<T> {
30
31
public:
31
32
void Compute (const framework::ExecutionContext& context) const override {
32
- auto * table_t = context.Input <LoDTensor>(" W" ); // float tensor
33
- auto * ids_t = context.Input <LoDTensor>(" Ids" ); // int tensor
34
- auto * output_t = context.Output <LoDTensor>(" Out" ); // float tensor
33
+ auto * table_t = context.Input <LoDTensor>(" W" ); // float tensor
34
+ auto * ids_var = context.InputVar (" Ids" ); // int tensor
35
+
36
+ int64_t * ids;
37
+ int64_t ids_numel;
38
+ Tensor* output_t ;
39
+
40
+ // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
41
+ // Maybe near future we will add concat_rows op.
42
+ if (ids_var->IsType <LoDTensor>()) {
43
+ auto * ids_t = context.Input <LoDTensor>(" Ids" );
44
+ output_t = context.Output <LoDTensor>(" Out" );
45
+ ids = const_cast <int64_t *>(ids_t ->data <int64_t >());
46
+ ids_numel = ids_t ->numel ();
47
+ } else if (ids_var->IsType <SelectedRows>()) {
48
+ auto * ids_t = context.Input <SelectedRows>(" Ids" );
49
+ output_t =
50
+ const_cast <Tensor*>(&(context.Output <SelectedRows>(" Out" )->value ()));
51
+ ids = const_cast <int64_t *>(ids_t ->rows ().data ());
52
+ ids_numel = ids_t ->rows ().size ();
53
+ output_t ->Resize ({ids_numel, table_t ->dims ()[1 ]});
54
+ } else {
55
+ PADDLE_THROW (" Unsupported Variable Type of Ids" );
56
+ }
57
+
35
58
int64_t padding_idx = context.Attr <int64_t >(" padding_idx" );
36
59
37
60
int N = table_t ->dims ()[0 ];
38
61
int D = table_t ->dims ()[1 ];
39
- auto * ids = ids_t ->data <int64_t >();
40
62
auto * table = table_t ->data <T>();
41
63
auto * output = output_t ->mutable_data <T>(context.GetPlace ());
42
64
43
65
if (padding_idx == -1 ) {
44
- for (int64_t i = 0 ; i < ids_t -> numel () ; ++i) {
66
+ for (int64_t i = 0 ; i < ids_numel ; ++i) {
45
67
PADDLE_ENFORCE_LT (ids[i], N);
46
68
PADDLE_ENFORCE_GE (ids[i], 0 );
47
69
memcpy (output + i * D, table + ids[i] * D, D * sizeof (T));
48
70
}
49
71
} else {
50
- for (int64_t i = 0 ; i < ids_t -> numel () ; ++i) {
72
+ for (int64_t i = 0 ; i < ids_numel ; ++i) {
51
73
if (ids[i] == padding_idx) {
52
74
memset (output + i * D, 0 , D * sizeof (T));
53
75
} else {
0 commit comments