@@ -31,93 +31,28 @@ class SeqExpandKernel : public framework::OpKernel<T> {
31
31
auto * out = context.Output <LoDTensor>(" Out" );
32
32
const T* x_data = x->data <T>();
33
33
auto x_dims = x->dims ();
34
- auto x_lod = x->lod ();
35
-
36
- framework::Vector<size_t > level;
37
- size_t num = (x_lod.size () == 0 ) ? (x->dims ()[0 ] + 1 ) : x_lod[0 ].size ();
38
- for (int i = 0 ; i < num; ++i) {
39
- level.push_back (i);
40
- }
41
- x_lod.push_back (level);
42
-
43
- size_t repeat = static_cast <size_t >(context.Attr <int >(" repeat" ));
44
- framework::Vector<size_t > scales;
45
- if (repeat != 0 ) {
46
- for (int i = 0 ; i < x_lod[0 ].size () - 1 ; ++i) {
47
- scales.push_back (repeat);
48
- }
49
- std::vector<int64_t > dims = framework::vectorize (x->dims ());
50
- dims[0 ] = dims[0 ] * repeat;
51
- auto out_dims = framework::make_ddim (dims);
52
- out->Resize (out_dims);
53
- } else {
54
- auto * y = context.Input <LoDTensor>(" Y" );
55
- auto y_lod = y->lod ();
56
- auto y_abs_lod = y_lod.ToAbsOffset ();
57
- auto x_abs_lod = x_lod.ToAbsOffset ();
58
- for (int i = 0 ; i < y_abs_lod[0 ].size () - 1 ; ++i) {
59
- scales.push_back ((y_abs_lod[0 ][i + 1 ] - y_abs_lod[0 ][i]) /
60
- (x_abs_lod[0 ][i + 1 ] - x_abs_lod[0 ][i]));
61
- }
62
- out->Resize (y->dims ());
63
- }
64
-
65
- framework::Vector<size_t > indexes;
66
- for (int size_t i = 0 ; i < x_lod[0 ]; ++i) {
67
- indexes[i] = x_lod[0 ];
68
- }
69
- framework::LoD out_lod;
70
- auto level0 = framework::expand_lod (indexes, x_lod[0 ], scales, false );
71
- out_lod.push_back (level0);
72
- for (int i = 1 ; i < x_lod.size (); ++i) {
73
- for (int j = 0 ; j < indexes.size (); ++j) {
74
- indexes[j] = x_lod[i - 1 ][indexes[j]];
75
- }
76
- out_lod.push_back (framework::expand_lod (x_lod[i], indexes, scales, true ));
77
- }
78
-
34
+ auto * y = context.Input <LoDTensor>(" Y" );
35
+ PADDLE_ENFORCE_EQ (x_dims[0 ], y->lod ().back ().size () - 1 ,
36
+ " The size of last lod level in Input(Y)"
37
+ " must be equal to dims[0] of Input(X)." );
38
+ out->set_lod (y->lod ());
39
+ out->Resize (y->dims ());
40
+ auto place = context.GetEigenDevice <Place>();
79
41
size_t element_len = framework::product (x_dims) / x_dims[0 ];
80
42
T* out_data = out->mutable_data <T>(context.GetPlace ());
81
-
82
- // copy data
83
- auto place = context.GetPlace ();
84
- size_t count = 0 ;
85
- if (platform::is_cpu_place (place)) {
86
- auto & cpu_place = boost::get<platform::CPUPlace>(place);
87
- for (size_t i = 0 ; i < scales.size (); ++i) {
88
- count = element_len * (x_abs_lod[0 ][i + 1 ] - x_abs_lod[0 ][i]);
89
- for (size_t j = 0 ; j < scales[i]; ++j) {
90
- memory::Copy (cpu_place, out_data, cpu_place, x_data,
91
- sizeof (T) * count);
92
- out_data += count;
93
- }
94
- x_data += count;
95
- }
96
- } else {
97
- #ifdef PADDLE_WITH_CUDA
98
- auto & gpu_place = boost::get<platform::GPUPlace>(place);
99
- auto stream = reinterpret_cast <const platform::CUDADeviceContext&>(
100
- context.device_context ())
101
- .stream ();
102
- for (size_t i = 0 ; i < scales.size (); ++i) {
103
- count = element_len * (x_abs_lod[0 ][i + 1 ] - x_abs_lod[0 ][i]);
104
- for (size_t j = 0 ; j < scales[i]; ++j) {
105
- memory::Copy (gpu_place, out_data, gpu_place, x_data,
106
- sizeof (T) * count, stream);
107
- out_data += count;
108
- }
109
- x_data += count;
110
- }
111
- #else
112
- PADDLE_THROW (" Paddle is not compiled with GPU" );
113
- #endif
114
- }
115
-
116
- out->set_lod (out_lod);
117
- for (size_t i = 0 ; i < lod.size ; i++) {
118
- for (size_t j = 0 ; j < lod[i].size (); j++) {
119
- LOG (INFO) << " lod[" << i << " ][" << j " ] = " << lod[i][j];
120
- }
43
+ auto out_starts = out->lod ().back ();
44
+
45
+ for (size_t i = 0 ; i < out_starts.size () - 1 ; i++) {
46
+ int scale = out_starts[i + 1 ] - out_starts[i];
47
+ Eigen::TensorMap<
48
+ Eigen::Tensor<const T, 2 , Eigen::RowMajor, Eigen::DenseIndex>>
49
+ x_t (x_data, 1 , element_len);
50
+ Eigen::TensorMap<Eigen::Tensor<T, 2 , Eigen::RowMajor, Eigen::DenseIndex>>
51
+ out_t (out_data, scale, element_len);
52
+ Eigen::array<int , 2 > cast ({scale, 1 });
53
+ out_t .device (place) = x_t .broadcast (cast);
54
+ x_data += element_len;
55
+ out_data += element_len * scale;
121
56
}
122
57
}
123
58
};
@@ -130,25 +65,24 @@ class SeqExpandGradKernel : public framework::OpKernel<T> {
130
65
auto * x = context.Input <LoDTensor>(" X" );
131
66
auto * out = context.Input <LoDTensor>(" Out" );
132
67
auto * d_x = context.Output <LoDTensor>(framework::GradVarName (" X" ));
133
- auto out_lod = out->lod ();
134
- auto out_abs_lod = out_lod.ToAbsOffset ();
68
+ auto out_last_level = out->lod ().back ();
135
69
d_x->set_lod (x->lod ());
136
70
const T* d_out_data = d_out->data <T>();
137
71
auto d_out_dims = d_out->dims ();
138
72
T* d_x_data = d_x->mutable_data <T>(context.GetPlace ());
139
73
size_t element_len = framework::product (d_out_dims) / d_out_dims[0 ];
140
- for ( size_t i = 0 ; i < out-> NumElements (); ++i) {
141
- size_t ele_count = out_abs_lod[ 0 ][i + 1 ] - out_abs_lod[ 0 ][i];
142
- size_t repeat = out-> NumElements ( 0 , i) ;
143
- Eigen::TensorMap<Eigen::Tensor< const T, 2 >> d_out_t (
144
- d_out_data, static_cast < int >(repeat),
145
- static_cast <int >((ele_count * element_len) / repeat));
146
- Eigen::TensorMap<Eigen::Tensor<T, 1 >> d_x_t (
147
- d_x_data, static_cast <int >((ele_count * element_len) / repeat ));
74
+
75
+ for ( size_t i = 0 ; i < out_last_level. size () - 1 ; ++i) {
76
+ size_t repeat = out_last_level[i + 1 ] - out_last_level[i] ;
77
+ Eigen::TensorMap<
78
+ Eigen::Tensor< const T, 2 , Eigen::RowMajor, Eigen::DenseIndex>>
79
+ d_out_t (d_out_data, static_cast <int >(repeat), element_len );
80
+ Eigen::TensorMap<Eigen::Tensor<T, 1 , Eigen::RowMajor, Eigen::DenseIndex>>
81
+ d_x_t ( d_x_data, static_cast <int >(element_len));
148
82
auto place = context.GetEigenDevice <Place>();
149
83
d_x_t .device (place) = d_out_t .sum (Eigen::array<int , 1 >({{0 }}));
150
- d_out_data += (ele_count * element_len);
151
- d_x_data += ((ele_count * element_len) / repeat) ;
84
+ d_out_data += (repeat * element_len);
85
+ d_x_data += element_len;
152
86
}
153
87
}
154
88
};
0 commit comments