Skip to content

Commit 2c31ea9 — Merge pull request #13424 from chengduoZH/refine_seq_concat: "Refine seq_concat".
Merge of 2 parents: 5996e22 and 2445950.

File tree: 14 files changed (+325 additions, −431 deletions).

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=Non
175175
paddle.fluid.layers.pad2d ArgSpec(args=['input', 'paddings', 'mode', 'pad_value', 'data_format', 'name'], varargs=None, keywords=None, defaults=([0, 0, 0, 0], 'constant', 0.0, 'NCHW', None))
176176
paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None))
177177
paddle.fluid.layers.sequence_enumerate ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None))
178+
paddle.fluid.layers.sequence_concat ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
178179
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
179180
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
180181
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))

paddle/fluid/framework/op_desc.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,10 @@ static void InitInferShapeFuncs() {
441441

442442
for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
443443
auto op_type = kern_pair.first;
444-
auto &op_info = info_map.at(op_type);
444+
auto it = info_map.find(op_type);
445+
PADDLE_ENFORCE(it != info_map.end(), "%s has not been registered",
446+
op_type);
447+
auto &op_info = it->second;
445448
auto op = static_cast<OperatorWithKernel *>(op_info.Creator()(
446449
"", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
447450
if (op_info.infer_shape_) { // infer_shape has been registered.

paddle/fluid/operators/concat_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
9595

9696
void InferShape(framework::InferShapeContext *ctx) const override {
9797
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
98+
ctx->ShareLoD("X", framework::GradVarName("X"));
9899
}
99100
};
100101

paddle/fluid/operators/concat_op.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
109109
auto& dev_ctx = ctx.template device_context<DeviceContext>();
110110
paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
111111
concat_grad_functor;
112-
concat_grad_functor(dev_ctx, *out_grad, ins, static_cast<int>(axis),
113-
&outputs);
112+
concat_grad_functor(dev_ctx, *out_grad,
113+
ctx.MultiInput<framework::Tensor>("X"),
114+
static_cast<int>(axis), &outputs);
114115
}
115116
}
116117
};

paddle/fluid/operators/detail/safe_ref.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#pragma once
16-
16+
#include <vector>
1717
#include "paddle/fluid/platform/enforce.h"
1818

1919
namespace paddle {
@@ -24,10 +24,22 @@ namespace detail {
2424
* and passed by `args`
2525
*/
2626
template <typename T, typename... ARGS>
27-
inline T &Ref(T *ptr, ARGS &&... args) {
27+
inline T& Ref(T* ptr, ARGS&&... args) {
2828
PADDLE_ENFORCE(ptr != nullptr, args...);
2929
return *ptr;
3030
}
31+
32+
template <typename T, typename... ARGS>
33+
inline std::vector<std::reference_wrapper<T>> VectorRef(
34+
const std::vector<T*>& vec, ARGS&&... args) {
35+
std::vector<std::reference_wrapper<T>> result;
36+
result.reserve(vec.size());
37+
for (auto* ptr : vec) {
38+
result.emplace_back(Ref(ptr, args...));
39+
}
40+
return result;
41+
}
42+
3143
} // namespace detail
3244
} // namespace operators
3345
} // namespace paddle

paddle/fluid/operators/math/concat.cc

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ template <typename T>
2727
class ConcatFunctor<platform::CPUDeviceContext, T> {
2828
public:
2929
void operator()(const platform::CPUDeviceContext& context,
30-
const std::vector<framework::Tensor>& input, const int axis,
30+
const std::vector<framework::Tensor>& input, int axis,
3131
framework::Tensor* output) {
3232
// TODO(zcd): Add input data validity checking
3333
int num = input.size();
@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
7171
public:
7272
void operator()(const platform::CPUDeviceContext& context,
7373
const framework::Tensor& input,
74-
const std::vector<const framework::LoDTensor*>& ref_inputs,
74+
const std::vector<const framework::Tensor*>& ref_inputs,
7575
const int axis, std::vector<framework::Tensor*>* outputs) {
7676
// TODO(zcd): Add input data validity checking
7777
size_t num = outputs->size();
@@ -109,16 +109,11 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
109109
}
110110
}
111111
};
112+
#define DEFINE_FUNCTOR(type) \
113+
template class ConcatFunctor<platform::CPUDeviceContext, type>; \
114+
template class ConcatGradFunctor<platform::CPUDeviceContext, type>;
112115

113-
template class ConcatFunctor<platform::CPUDeviceContext, int>;
114-
template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
115-
template class ConcatFunctor<platform::CPUDeviceContext, float>;
116-
template class ConcatFunctor<platform::CPUDeviceContext, double>;
117-
118-
template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
119-
template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
120-
template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
121-
template class ConcatGradFunctor<platform::CPUDeviceContext, double>;
116+
FOR_ALL_TYPES(DEFINE_FUNCTOR);
122117

123118
} // namespace math
124119
} // namespace operators

paddle/fluid/operators/math/concat.cu

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include "paddle/fluid/framework/mixed_vector.h"
1818
#include "paddle/fluid/operators/math/concat.h"
1919
#include "paddle/fluid/platform/cuda_primitives.h"
20+
#include "paddle/fluid/platform/float16.h"
2021

2122
namespace paddle {
2223
namespace operators {
@@ -118,7 +119,7 @@ template <typename T>
118119
class ConcatFunctor<platform::CUDADeviceContext, T> {
119120
public:
120121
void operator()(const platform::CUDADeviceContext& context,
121-
const std::vector<framework::Tensor>& input, const int axis,
122+
const std::vector<framework::Tensor>& input, int axis,
122123
framework::Tensor* output) {
123124
// TODO(zcd): Add input data validity checking
124125
int in_num = input.size();
@@ -192,8 +193,8 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
192193
public:
193194
void operator()(const platform::CUDADeviceContext& context,
194195
const framework::Tensor& input,
195-
const std::vector<const framework::LoDTensor*>& ref_inputs,
196-
const int axis, std::vector<framework::Tensor*>* outputs) {
196+
const std::vector<const framework::Tensor*>& ref_inputs,
197+
int axis, std::vector<framework::Tensor*>* outputs) {
197198
// TODO(zcd): Add input data validity checking
198199
int o_num = outputs->size();
199200
int out_row = 1;
@@ -261,15 +262,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
261262
}
262263
};
263264

264-
template class ConcatFunctor<platform::CUDADeviceContext, int>;
265-
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
266-
template class ConcatFunctor<platform::CUDADeviceContext, float>;
267-
template class ConcatFunctor<platform::CUDADeviceContext, double>;
265+
#define DEFINE_FUNCTOR(type) \
266+
template class ConcatFunctor<platform::CUDADeviceContext, type>; \
267+
template class ConcatGradFunctor<platform::CUDADeviceContext, type>
268268

269-
template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
270-
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
271-
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
272-
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
269+
FOR_ALL_TYPES(DEFINE_FUNCTOR);
273270

274271
} // namespace math
275272
} // namespace operators

paddle/fluid/operators/math/concat.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ template <typename DeviceContext, typename T>
3737
class ConcatFunctor {
3838
public:
3939
void operator()(const DeviceContext& context,
40-
const std::vector<framework::Tensor>& input, const int axis,
40+
const std::vector<framework::Tensor>& input, int axis,
4141
framework::Tensor* output);
4242
};
4343

@@ -57,10 +57,21 @@ template <typename DeviceContext, typename T>
5757
class ConcatGradFunctor {
5858
public:
5959
void operator()(const DeviceContext& context, const framework::Tensor& input,
60-
const std::vector<const framework::LoDTensor*>& ref_inputs,
61-
const int axis, std::vector<framework::Tensor*>* outputs);
60+
const std::vector<const framework::Tensor*>& ref_inputs,
61+
int axis, std::vector<framework::Tensor*>* outputs);
6262
};
6363

6464
} // namespace math
6565
} // namespace operators
6666
} // namespace paddle
67+
68+
#define FOR_ALL_TYPES(macro) \
69+
macro(int); \
70+
macro(float); \
71+
macro(double); \
72+
macro(bool); \
73+
macro(int64_t); \
74+
macro(int16_t); \
75+
macro(uint8_t); \
76+
macro(int8_t); \
77+
macro(::paddle::platform::float16)

Commit comments: 0