19
19
namespace paddle {
20
20
namespace operators {
21
21
22
- using Tensor = framework::Tensor ;
22
+ using LoDTensor = framework::LoDTensor ;
23
23
using SelectedRows = framework::SelectedRows;
24
24
25
25
template <typename T>
26
26
class LookupTableKernel : public framework ::OpKernel<T> {
27
27
public:
28
28
void Compute (const framework::ExecutionContext& context) const override {
29
- auto table_t = context.Input <Tensor >(" W" ); // float tensor
30
- auto ids_t = context.Input <Tensor >(" Ids" ); // int tensor
31
- auto output_t = context.Output <Tensor >(" Out" ); // float tensor
29
+ auto * table_t = context.Input <LoDTensor >(" W" ); // float tensor
30
+ auto * ids_t = context.Input <LoDTensor >(" Ids" ); // int tensor
31
+ auto * output_t = context.Output <LoDTensor >(" Out" ); // float tensor
32
32
33
33
int N = table_t ->dims ()[0 ];
34
34
int D = table_t ->dims ()[1 ];
35
- auto ids = ids_t ->data <int64_t >();
36
- auto table = table_t ->data <T>();
37
- auto output = output_t ->mutable_data <T>(context.GetPlace ());
35
+ auto * ids = ids_t ->data <int64_t >();
36
+ auto * table = table_t ->data <T>();
37
+ auto * output = output_t ->mutable_data <T>(context.GetPlace ());
38
38
for (int64_t i = 0 ; i < ids_t ->numel (); ++i) {
39
39
PADDLE_ENFORCE_LT (ids[i], N);
40
40
PADDLE_ENFORCE_GE (ids[i], 0 );
@@ -49,9 +49,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
49
49
void Compute (const framework::ExecutionContext& context) const override {
50
50
bool is_sparse = context.Attr <bool >(" is_sparse" );
51
51
if (is_sparse) {
52
- auto * ids = context.Input <Tensor >(" Ids" );
53
- auto * table = context.Input <Tensor >(" W" );
54
- auto * d_output = context.Input <Tensor >(framework::GradVarName (" Out" ));
52
+ auto * ids = context.Input <LoDTensor >(" Ids" );
53
+ auto * table = context.Input <LoDTensor >(" W" );
54
+ auto * d_output = context.Input <LoDTensor >(framework::GradVarName (" Out" ));
55
55
auto * d_table = context.Output <SelectedRows>(framework::GradVarName (" W" ));
56
56
57
57
auto * ids_data = ids->data <int64_t >();
@@ -76,10 +76,10 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
76
76
PADDLE_ENFORCE_EQ (d_table_value->dims (), d_output->dims ());
77
77
memcpy (d_table_data, d_output_data, sizeof (T) * d_output->numel ());
78
78
} else {
79
- auto * ids = context.Input <Tensor >(" Ids" );
80
- auto * d_output = context.Input <Tensor >(framework::GradVarName (" Out" ));
81
- auto * d_table = context.Output <Tensor >(framework::GradVarName (" W" ));
82
- auto * table = context.Input <Tensor >(" W" );
79
+ auto * ids = context.Input <LoDTensor >(" Ids" );
80
+ auto * d_output = context.Input <LoDTensor >(framework::GradVarName (" Out" ));
81
+ auto * d_table = context.Output <LoDTensor >(framework::GradVarName (" W" ));
82
+ auto * table = context.Input <LoDTensor >(" W" );
83
83
84
84
auto * ids_data = ids->data <int64_t >();
85
85
auto ids_dim = ids->dims ();
0 commit comments