@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
+ #include < algorithm>
15
16
#include " paddle/fluid/operators/math/sequence_padding.h"
16
17
17
18
namespace paddle {
@@ -61,7 +62,7 @@ template <typename T>
61
62
class PaddingLoDTensorFunctor <platform::CUDADeviceContext, T> {
62
63
public:
63
64
void operator ()(const platform::CUDADeviceContext& context,
64
- const framework::LoDTensor& seq, framework::Tensor& padding,
65
+ const framework::LoDTensor& seq, framework::Tensor* padding,
65
66
bool norm_by_times) {
66
67
auto lod = seq.lod ();
67
68
PADDLE_ENFORCE_GT (lod.size (), 0UL ,
@@ -76,7 +77,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
76
77
" The first dimension of LoDTensor seq should be "
77
78
" equal to the sum of all sequences's length." );
78
79
79
- auto padding_dims = padding. dims ();
80
+ auto padding_dims = padding-> dims ();
80
81
PADDLE_ENFORCE_EQ (padding_dims.size (), 3UL ,
81
82
" The input padding should be a 3-D Tensor of shape "
82
83
" [max_sequence_length, num_sequences, sequence_width]." );
@@ -97,8 +98,8 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
97
98
" width of sequence in LoDTensor seq." );
98
99
99
100
if (!norm_by_times && num_sequences == 1UL ) {
100
- TensorCopy (seq, context.GetPlace (), context, & padding);
101
- padding. Resize (padding_dims);
101
+ TensorCopy (seq, context.GetPlace (), context, padding);
102
+ padding-> Resize (padding_dims);
102
103
return ;
103
104
}
104
105
@@ -117,7 +118,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
117
118
dim3 grid (grid_dim_x, grid_dim_y);
118
119
119
120
const T* seq_data = seq.data <T>();
120
- T* padding_data = padding. data <T>();
121
+ T* padding_data = padding-> data <T>();
121
122
if (norm_by_times) {
122
123
SequencePaddingKernel<T, 1 , 1 ><<<grid, threads, 0 , context.stream()>>> (
123
124
padding_data, const_cast <T*>(seq_data),
@@ -136,16 +137,16 @@ template <typename T>
136
137
class UnpaddingLoDTensorFunctor <platform::CUDADeviceContext, T> {
137
138
public:
138
139
void operator ()(const platform::CUDADeviceContext& context,
139
- framework::LoDTensor& seq, const framework::Tensor& padding,
140
+ framework::LoDTensor* seq, const framework::Tensor& padding,
140
141
bool norm_by_times) {
141
- auto lod = seq. lod ();
142
+ auto lod = seq-> lod ();
142
143
PADDLE_ENFORCE_GT (lod.size (), 0UL ,
143
144
" The lod of LoDTensor seq should not be null." );
144
145
145
146
const size_t level = 0 ;
146
147
framework::LoD abs_offset_lod = framework::ToAbsOffset (lod);
147
148
148
- auto seq_dims = seq. dims ();
149
+ auto seq_dims = seq-> dims ();
149
150
PADDLE_ENFORCE_EQ (seq_dims[0 ],
150
151
static_cast <int64_t >(abs_offset_lod[level].back ()),
151
152
" The first dimension of LoDTensor seq should be "
@@ -166,14 +167,14 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
166
167
" The second dimension of Tensor padding should be "
167
168
" the number of sequences in LoDTensor seq." );
168
169
169
- const int64_t sequence_width = seq. numel () / seq_dims[0 ];
170
+ const int64_t sequence_width = seq-> numel () / seq_dims[0 ];
170
171
PADDLE_ENFORCE_EQ (padding_dims[2 ], sequence_width,
171
172
" The third dimension of Tensor padding should be the "
172
173
" width of sequence in LoDTensor seq." );
173
174
174
175
if (!norm_by_times && num_sequences == 1UL ) {
175
- TensorCopy (padding, context.GetPlace (), context, & seq);
176
- seq. Resize (seq_dims);
176
+ TensorCopy (padding, context.GetPlace (), context, seq);
177
+ seq-> Resize (seq_dims);
177
178
return ;
178
179
}
179
180
@@ -192,7 +193,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
192
193
dim3 grid (grid_dim_x, grid_dim_y);
193
194
194
195
const T* padding_data = padding.data <T>();
195
- T* seq_data = seq. data <T>();
196
+ T* seq_data = seq-> data <T>();
196
197
if (norm_by_times) {
197
198
SequencePaddingKernel<T, 1 , 0 ><<<grid, threads, 0 , context.stream()>>> (
198
199
const_cast <T*>(padding_data), seq_data,
0 commit comments