Skip to content

Commit 0d75191

Browse files
committed
speed up lod_tensor to array and array to lod_tensor
1 parent 437debf commit 0d75191

File tree

7 files changed

+159
-48
lines changed

7 files changed

+159
-48
lines changed

paddle/fluid/operators/array_to_lod_tensor_op.cc

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14+
#include <paddle/fluid/operators/math/concat.h>
1415
#include <numeric>
1516

1617
#include "paddle/fluid/framework/lod_rank_table.h"
@@ -24,6 +25,50 @@ namespace operators {
2425

2526
using LoD = framework::LoD;
2627

28+
class ArrayToLoDFunctor;
29+
template <typename DeviceContext>
30+
struct ArrayToLoDFunctorImpl {
31+
const ArrayToLoDFunctor *prev_functor_;
32+
DeviceContext *dev_ctx_;
33+
34+
template <typename T>
35+
void apply();
36+
};
37+
38+
struct ArrayToLoDFunctor : public boost::static_visitor<void> {
39+
std::vector<framework::Tensor> in;
40+
mutable framework::Tensor *out;
41+
42+
template <typename Place>
43+
void operator()(Place place) const {
44+
auto &pool = platform::DeviceContextPool::Instance();
45+
if (std::is_same<Place, platform::CPUPlace>::value) {
46+
Apply(static_cast<platform::CPUDeviceContext *>(pool.Get(place)));
47+
} else {
48+
#ifdef PADDLE_WITH_CUDA
49+
Apply(static_cast<platform::CUDADeviceContext *>(pool.Get(place)));
50+
#else
51+
PADDLE_THROW("Fluid is not compiled with CUDA");
52+
#endif
53+
}
54+
}
55+
56+
template <typename DeviceContext>
57+
void Apply(DeviceContext *dev_ctx) const {
58+
ArrayToLoDFunctorImpl<DeviceContext> functor;
59+
functor.dev_ctx_ = dev_ctx;
60+
functor.prev_functor_ = this;
61+
framework::VisitDataType(framework::ToDataType(out->type()), functor);
62+
}
63+
};
64+
65+
template <typename DeviceContext>
66+
template <typename T>
67+
void ArrayToLoDFunctorImpl<DeviceContext>::apply() {
68+
math::ConcatFunctor<DeviceContext, T> func;
69+
func(*dev_ctx_, prev_functor_->in, 0, prev_functor_->out);
70+
}
71+
2772
class ArrayToLoDTensorOp : public framework::OperatorBase {
2873
public:
2974
ArrayToLoDTensorOp(const std::string &type,
@@ -47,14 +92,18 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
4792
int rank = x[0].dims().size();
4893
platform::Place place = x[0].place();
4994
std::type_index data_type = x[0].type();
50-
framework::DDim ins_dims = framework::slice_ddim(x[0].dims(), 1, rank);
5195
int64_t batch_size = x[0].dims()[0];
96+
framework::DDim ins_dims = rank > 1
97+
? framework::slice_ddim(x[0].dims(), 1, rank)
98+
: framework::make_ddim({0});
5299
for (size_t i = 1; i < x.size(); ++i) {
53-
PADDLE_ENFORCE_EQ(framework::slice_ddim(x[i].dims(), 1, rank), ins_dims,
100+
auto ins_i_dims = rank > 1 ? framework::slice_ddim(x[i].dims(), 1, rank)
101+
: framework::make_ddim({0});
102+
PADDLE_ENFORCE_EQ(ins_i_dims, ins_dims,
54103
"The dimension of the %zu'th element in LoDTensorArray "
55104
"differs from previous ones.",
56105
i);
57-
PADDLE_ENFORCE(platform::places_are_same_class(x[i].place(), place),
106+
PADDLE_ENFORCE(x[i].place() == place,
58107
"The place class of the %zu'th element in LoDTensorArray "
59108
"differs from previous ones.",
60109
i);
@@ -82,13 +131,14 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
82131
// Build LoDTensor `out`
83132
framework::LoD *out_lod = out->mutable_lod();
84133
out_lod->clear();
85-
size_t out_offset = 0;
86134
auto prefix_lod = rank_table.coarse_lod();
87135
prefix_lod.emplace_back();
88136
auto &cur_level_lod = prefix_lod.back();
89137
cur_level_lod.push_back(0);
138+
ArrayToLoDFunctor functor;
90139
for (size_t idx : table_item_idx) {
91140
cur_level_lod.push_back(cur_level_lod.back() + table_items[idx].length);
141+
PADDLE_ENFORCE_LE(table_items[idx].length, x.size());
92142
for (size_t x_idx = 0; x_idx < table_items[idx].length; ++x_idx) {
93143
auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset(
94144
x[x_idx].lod(), idx, idx + 1, 0);
@@ -106,17 +156,11 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
106156
if (len == 0) {
107157
continue;
108158
}
109-
auto slice = out->Slice(out_offset, out_offset + len);
110-
111-
platform::DeviceContextPool &pool =
112-
platform::DeviceContextPool::Instance();
113-
auto &dev_ctx = *pool.Get(place);
114-
115-
framework::TensorCopy(x[x_idx].Slice(start_offset, end_offset), place,
116-
dev_ctx, &slice);
117-
out_offset += len;
159+
functor.in.emplace_back(x[x_idx].Slice(start_offset, end_offset));
118160
}
119161
}
162+
functor.out = out;
163+
platform::VisitPlace(place, functor);
120164
out_lod->insert(out_lod->begin(), prefix_lod.begin(), prefix_lod.end());
121165
}
122166
};

paddle/fluid/operators/concat_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
9595

9696
void InferShape(framework::InferShapeContext *ctx) const override {
9797
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
98+
ctx->ShareLoD("X", framework::GradVarName("X"));
9899
}
99100
};
100101

paddle/fluid/operators/concat_op.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
109109
auto& dev_ctx = ctx.template device_context<DeviceContext>();
110110
paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
111111
concat_grad_functor;
112-
concat_grad_functor(dev_ctx, *out_grad, ins, static_cast<int>(axis),
113-
&outputs);
112+
concat_grad_functor(dev_ctx, *out_grad,
113+
ctx.MultiInput<framework::Tensor>("X"),
114+
static_cast<int>(axis), &outputs);
114115
}
115116
}
116117
};

paddle/fluid/operators/lod_tensor_to_array_op.cc

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
1111
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
14+
#include <algorithm>
15+
#include <map>
1416
#include "paddle/fluid/framework/lod_rank_table.h"
1517
#include "paddle/fluid/framework/lod_tensor_array.h"
1618
#include "paddle/fluid/framework/op_registry.h"
1719
#include "paddle/fluid/operators/detail/safe_ref.h"
20+
#include "paddle/fluid/operators/math/concat.h"
1821
#include "paddle/fluid/platform/device_context.h"
1922
#include "paddle/fluid/platform/port.h"
2023

@@ -26,6 +29,61 @@ struct CopyRange {
2629
size_t end;
2730
};
2831

32+
struct LoDTensorToArrayFunctor;
33+
34+
template <typename DeviceContext>
35+
struct LoDTensorToArrayFunctorImpl {
36+
const LoDTensorToArrayFunctor *prev_functor_;
37+
DeviceContext *dev_ctx_;
38+
template <typename T>
39+
void apply();
40+
};
41+
42+
struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
43+
std::vector<const framework::Tensor *> ref_inputs_;
44+
mutable std::vector<framework::Tensor *> outputs_;
45+
const framework::Tensor &input_;
46+
47+
explicit LoDTensorToArrayFunctor(const framework::Tensor &input)
48+
: input_(input) {}
49+
50+
void AddOutput(framework::Tensor *t) {
51+
outputs_.emplace_back(t);
52+
ref_inputs_.emplace_back(t);
53+
}
54+
55+
template <typename Place>
56+
void operator()(Place place) const {
57+
auto &pool = platform::DeviceContextPool::Instance();
58+
auto *dev_ctx = pool.Get(place);
59+
if (std::is_same<Place, platform::CPUPlace>::value) {
60+
Apply(static_cast<platform::CPUDeviceContext *>(dev_ctx));
61+
} else {
62+
#ifdef PADDLE_WITH_CUDA
63+
Apply(static_cast<platform::CUDADeviceContext *>(dev_ctx));
64+
#else
65+
PADDLE_THROW("Not compiled with cuda");
66+
#endif
67+
}
68+
}
69+
70+
template <typename DeviceContext>
71+
void Apply(DeviceContext *dev_ctx) const {
72+
LoDTensorToArrayFunctorImpl<DeviceContext> func;
73+
func.prev_functor_ = this;
74+
func.dev_ctx_ = dev_ctx;
75+
framework::VisitDataType(framework::ToDataType(input_.type()), func);
76+
}
77+
};
78+
79+
template <typename DeviceContext>
80+
template <typename T>
81+
void LoDTensorToArrayFunctorImpl<DeviceContext>::apply() {
82+
math::ConcatGradFunctor<DeviceContext, T> func;
83+
func(*dev_ctx_, prev_functor_->input_, prev_functor_->ref_inputs_, 0,
84+
&prev_functor_->outputs_);
85+
}
86+
2987
class LoDTensorToArrayOp : public framework::OperatorBase {
3088
public:
3189
LoDTensorToArrayOp(const std::string &type,
@@ -72,6 +130,11 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
72130
copy_ranges[t].emplace_back(CopyRange{start_offset, end_offset});
73131
}
74132
}
133+
134+
auto &outputs = *const_cast<framework::Scope &>(scope)
135+
.Var()
136+
->GetMutable<std::map<size_t, framework::Tensor>>();
137+
75138
for (size_t i = 0; i < max_seq_len; ++i) {
76139
auto &ranges = copy_ranges[i];
77140
size_t height = std::accumulate(
@@ -90,17 +153,16 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
90153
// out[i][offset: offset+len] = x[each_range.begin: each_range.end]
91154
auto slice = out[i].Slice(static_cast<int>(offset),
92155
static_cast<int>(offset + len));
93-
94-
platform::DeviceContextPool &pool =
95-
platform::DeviceContextPool::Instance();
96-
auto &dev_ctx = *pool.Get(place);
97-
98-
framework::TensorCopy(x.Slice(static_cast<int>(each_range.begin),
99-
static_cast<int>(each_range.end)),
100-
x.place(), dev_ctx, &slice);
156+
outputs.insert({each_range.begin, slice});
101157
offset += len;
102158
}
103159
}
160+
161+
LoDTensorToArrayFunctor functor(x);
162+
for (auto &out_pair : outputs) {
163+
functor.AddOutput(&out_pair.second);
164+
}
165+
platform::VisitPlace(place, functor);
104166
}
105167
};
106168

paddle/fluid/operators/math/concat.cc

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ template <typename T>
2727
class ConcatFunctor<platform::CPUDeviceContext, T> {
2828
public:
2929
void operator()(const platform::CPUDeviceContext& context,
30-
const std::vector<framework::Tensor>& input, const int axis,
30+
const std::vector<framework::Tensor>& input, int axis,
3131
framework::Tensor* output) {
3232
// TODO(zcd): Add input data validity checking
3333
int num = input.size();
@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
7171
public:
7272
void operator()(const platform::CPUDeviceContext& context,
7373
const framework::Tensor& input,
74-
const std::vector<const framework::LoDTensor*>& ref_inputs,
74+
const std::vector<const framework::Tensor*>& ref_inputs,
7575
const int axis, std::vector<framework::Tensor*>* outputs) {
7676
// TODO(zcd): Add input data validity checking
7777
size_t num = outputs->size();
@@ -109,16 +109,11 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
109109
}
110110
}
111111
};
112+
#define DEFINE_FUNCTOR(type) \
113+
template class ConcatFunctor<platform::CPUDeviceContext, type>; \
114+
template class ConcatGradFunctor<platform::CPUDeviceContext, type>;
112115

113-
template class ConcatFunctor<platform::CPUDeviceContext, int>;
114-
template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
115-
template class ConcatFunctor<platform::CPUDeviceContext, float>;
116-
template class ConcatFunctor<platform::CPUDeviceContext, double>;
117-
118-
template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
119-
template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
120-
template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
121-
template class ConcatGradFunctor<platform::CPUDeviceContext, double>;
116+
FOR_ALL_TYPES(DEFINE_FUNCTOR);
122117

123118
} // namespace math
124119
} // namespace operators

paddle/fluid/operators/math/concat.cu

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include "paddle/fluid/framework/mixed_vector.h"
1818
#include "paddle/fluid/operators/math/concat.h"
1919
#include "paddle/fluid/platform/cuda_primitives.h"
20+
#include "paddle/fluid/platform/float16.h"
2021

2122
namespace paddle {
2223
namespace operators {
@@ -118,7 +119,7 @@ template <typename T>
118119
class ConcatFunctor<platform::CUDADeviceContext, T> {
119120
public:
120121
void operator()(const platform::CUDADeviceContext& context,
121-
const std::vector<framework::Tensor>& input, const int axis,
122+
const std::vector<framework::Tensor>& input, int axis,
122123
framework::Tensor* output) {
123124
// TODO(zcd): Add input data validity checking
124125
int in_num = input.size();
@@ -192,8 +193,8 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
192193
public:
193194
void operator()(const platform::CUDADeviceContext& context,
194195
const framework::Tensor& input,
195-
const std::vector<const framework::LoDTensor*>& ref_inputs,
196-
const int axis, std::vector<framework::Tensor*>* outputs) {
196+
const std::vector<const framework::Tensor*>& ref_inputs,
197+
int axis, std::vector<framework::Tensor*>* outputs) {
197198
// TODO(zcd): Add input data validity checking
198199
int o_num = outputs->size();
199200
int out_row = 1;
@@ -261,15 +262,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
261262
}
262263
};
263264

264-
template class ConcatFunctor<platform::CUDADeviceContext, int>;
265-
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
266-
template class ConcatFunctor<platform::CUDADeviceContext, float>;
267-
template class ConcatFunctor<platform::CUDADeviceContext, double>;
265+
#define DEFINE_FUNCTOR(type) \
266+
template class ConcatFunctor<platform::CUDADeviceContext, type>; \
267+
template class ConcatGradFunctor<platform::CUDADeviceContext, type>
268268

269-
template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
270-
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
271-
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
272-
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
269+
FOR_ALL_TYPES(DEFINE_FUNCTOR);
273270

274271
} // namespace math
275272
} // namespace operators

paddle/fluid/operators/math/concat.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ template <typename DeviceContext, typename T>
3737
class ConcatFunctor {
3838
public:
3939
void operator()(const DeviceContext& context,
40-
const std::vector<framework::Tensor>& input, const int axis,
40+
const std::vector<framework::Tensor>& input, int axis,
4141
framework::Tensor* output);
4242
};
4343

@@ -57,10 +57,21 @@ template <typename DeviceContext, typename T>
5757
class ConcatGradFunctor {
5858
public:
5959
void operator()(const DeviceContext& context, const framework::Tensor& input,
60-
const std::vector<const framework::LoDTensor*>& ref_inputs,
61-
const int axis, std::vector<framework::Tensor*>* outputs);
60+
const std::vector<const framework::Tensor*>& ref_inputs,
61+
int axis, std::vector<framework::Tensor*>* outputs);
6262
};
6363

6464
} // namespace math
6565
} // namespace operators
6666
} // namespace paddle
67+
68+
#define FOR_ALL_TYPES(macro) \
69+
macro(int); \
70+
macro(float); \
71+
macro(double); \
72+
macro(bool); \
73+
macro(int64_t); \
74+
macro(int16_t); \
75+
macro(uint8_t); \
76+
macro(int8_t); \
77+
macro(::paddle::platform::float16)

0 commit comments

Comments
 (0)