Skip to content

Commit 5ea7bf8

Browse files
authored
Merge pull request #12872 from sneaxiy/stack_op
Add stack_op for DAM model
2 parents 405d6d0 + ba168bd commit 5ea7bf8

File tree

7 files changed

+487
-0
lines changed

7 files changed

+487
-0
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs
162162
paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
163163
paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None))
164164
paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
165+
paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
165166
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
166167
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
167168
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))

paddle/fluid/framework/array.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <cstdint>
18+
#include "paddle/fluid/platform/hostdevice.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
template <typename T, size_t N>
23+
class Array {
24+
static_assert(N > 0, "The size of array must be larger than 0");
25+
26+
public:
27+
HOSTDEVICE Array() {}
28+
29+
HOSTDEVICE explicit Array(const T &val) {
30+
for (size_t i = 0; i < N; ++i) data_[i] = val;
31+
}
32+
33+
HOSTDEVICE const T *Get() const { return data_; }
34+
35+
HOSTDEVICE T *GetMutable() { return data_; }
36+
37+
HOSTDEVICE T &operator[](size_t index) { return data_[index]; }
38+
39+
HOSTDEVICE const T &operator[](size_t index) const { return data_[index]; }
40+
41+
HOSTDEVICE constexpr size_t size() const { return N; }
42+
43+
private:
44+
T data_[N];
45+
};
46+
47+
} // namespace framework
48+
} // namespace paddle

paddle/fluid/operators/stack_op.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/stack_op.h"
16+
17+
namespace plat = paddle::platform;
18+
namespace ops = paddle::operators;
19+
REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
20+
ops::StackGradOpDescMaker);
21+
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
22+
23+
REGISTER_OP_CPU_KERNEL(stack, ops::StackKernel<plat::CPUDeviceContext, float>,
24+
ops::StackKernel<plat::CPUDeviceContext, double>);
25+
26+
REGISTER_OP_CPU_KERNEL(stack_grad,
27+
ops::StackGradKernel<plat::CPUDeviceContext, float>,
28+
ops::StackGradKernel<plat::CPUDeviceContext, double>);

paddle/fluid/operators/stack_op.cu

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/stack_op.h"
16+
17+
namespace plat = paddle::platform;
18+
namespace ops = paddle::operators;
19+
20+
REGISTER_OP_CUDA_KERNEL(stack, ops::StackKernel<plat::CUDADeviceContext, float>,
21+
ops::StackKernel<plat::CUDADeviceContext, double>);
22+
23+
REGISTER_OP_CUDA_KERNEL(stack_grad,
24+
ops::StackGradKernel<plat::CUDADeviceContext, float>,
25+
ops::StackGradKernel<plat::CUDADeviceContext, double>);

paddle/fluid/operators/stack_op.h

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/platform/for_range.h"
19+
20+
#ifdef __NVCC__
21+
#include <thrust/device_vector.h>
22+
#include "paddle/fluid/framework/array.h"
23+
#endif
24+
25+
namespace paddle {
26+
namespace operators {
27+
28+
class StackOp : public framework::OperatorWithKernel {
29+
public:
30+
using framework::OperatorWithKernel::OperatorWithKernel;
31+
32+
void InferShape(framework::InferShapeContext *ctx) const override {
33+
PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0,
34+
"Number of Inputs(X) must be larger than 0");
35+
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist.");
36+
37+
auto input_dims = ctx->GetInputsDim("X");
38+
for (size_t i = 1; i < input_dims.size(); ++i) {
39+
PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
40+
"Dims of all Inputs(X) must be the same");
41+
}
42+
43+
// Only lod of X[0] would be shared with Y
44+
ctx->ShareLoD("X", /*->*/ "Y");
45+
46+
int axis = ctx->Attrs().Get<int>("axis");
47+
int rank = input_dims[0].size();
48+
PADDLE_ENFORCE(
49+
axis >= -(rank + 1) && axis < rank + 1,
50+
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
51+
if (axis < 0) axis += (rank + 1);
52+
53+
auto vec = framework::vectorize2int(input_dims[0]);
54+
vec.insert(vec.begin() + axis, input_dims.size());
55+
ctx->SetOutputDim("Y", framework::make_ddim(vec));
56+
}
57+
};
58+
59+
// Declares the inputs, outputs, attributes and doc string of the stack op.
class StackOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // X is duplicable: the op takes a list of tensors under this name.
    AddInput("X", "The input of stack op.").AsDuplicable();
    AddOutput("Y", "The output of stack op.");

    auto &axis_attr = AddAttr<int>(
        "axis", "The axis along which all of the Inputs(X) should be stacked.");
    axis_attr.SetDefault(0);

    AddComment(R"DOC(
Stack Operator.

Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
)DOC");
  }
};
74+
75+
template <typename VecXType, typename T>
76+
struct StackFunctor {
77+
HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post)
78+
: x_(x), y_(y), n_(n), post_(post) {}
79+
80+
HOSTDEVICE void operator()(int idx) {
81+
int i = idx / (n_ * post_);
82+
int which_x = idx / post_ - i * n_;
83+
int x_index = i * post_ + idx % post_;
84+
y_[idx] = x_[which_x][x_index];
85+
}
86+
87+
private:
88+
VecXType x_;
89+
T *y_;
90+
int n_;
91+
int post_;
92+
};
93+
94+
template <typename VecDxType, typename T>
95+
struct StackGradFunctor {
96+
HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post)
97+
: dx_(dx), dy_(dy), n_(n), post_(post) {}
98+
99+
HOSTDEVICE void operator()(int idx) {
100+
int i = idx / (n_ * post_);
101+
int which_x = idx / post_ - i * n_;
102+
int x_index = i * post_ + idx % post_;
103+
dx_[which_x][x_index] = dy_[idx];
104+
}
105+
106+
private:
107+
VecDxType dx_;
108+
const T *dy_;
109+
int n_;
110+
int post_;
111+
};
112+
113+
template <typename DeviceContext, typename VecXType, typename T>
114+
static inline void StackFunctorForRange(const DeviceContext &ctx,
115+
const VecXType &x, T *y, int total_num,
116+
int n, int post) {
117+
platform::ForRange<DeviceContext> for_range(ctx, total_num);
118+
for_range(StackFunctor<VecXType, T>(x, y, n, post));
119+
}
120+
121+
template <typename DeviceContext, typename VecDxType, typename T>
122+
static inline void StackGradFunctorForRange(const DeviceContext &ctx,
123+
const VecDxType &dx, const T *dy,
124+
int total_num, int n, int post) {
125+
platform::ForRange<DeviceContext> for_range(ctx, total_num);
126+
for_range(StackGradFunctor<VecDxType, T>(dx, dy, n, post));
127+
}
128+
129+
template <typename DeviceContext, typename T>
130+
class StackKernel : public framework::OpKernel<T> {
131+
using Tensor = framework::LoDTensor;
132+
133+
public:
134+
void Compute(const framework::ExecutionContext &ctx) const override {
135+
auto x = ctx.MultiInput<Tensor>("X");
136+
auto *y = ctx.Output<Tensor>("Y");
137+
138+
int axis = ctx.Attr<int>("axis");
139+
if (axis < 0) axis += (x[0]->dims().size() + 1);
140+
141+
int n = static_cast<int>(x.size());
142+
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
143+
std::vector<const T *> x_datas(n);
144+
for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<T>();
145+
146+
int pre = 1, post = 1;
147+
auto &dim = x[0]->dims();
148+
for (auto i = 0; i < axis; ++i) pre *= dim[i];
149+
for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
150+
int total_num = pre * n * post;
151+
152+
auto &dev_ctx = ctx.template device_context<DeviceContext>();
153+
constexpr auto kMaxThreshold = 16;
154+
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
155+
n > kMaxThreshold) {
156+
#ifdef __NVCC__
157+
VLOG(10) << "Stack more than " << kMaxThreshold
158+
<< " tensors on GPU may be slow.";
159+
thrust::device_vector<const T *> device_x_vec(x_datas);
160+
auto x_data_arr = device_x_vec.data().get();
161+
#else
162+
auto x_data_arr = x_datas.data();
163+
#endif
164+
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
165+
#ifdef __NVCC__
166+
// Wait() must be called because device_x_vec may be destructed before
167+
// kernel ends
168+
dev_ctx.Wait();
169+
#endif
170+
}
171+
#ifdef __NVCC__
172+
else { // NOLINT
173+
framework::Array<const T *, kMaxThreshold> x_data_arr;
174+
for (int i = 0; i < n; ++i) x_data_arr[i] = x_datas[i];
175+
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
176+
}
177+
#endif
178+
}
179+
};
180+
181+
class StackOpGrad : public framework::OperatorWithKernel {
182+
public:
183+
using framework::OperatorWithKernel::OperatorWithKernel;
184+
185+
void InferShape(framework::InferShapeContext *ctx) const override {
186+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
187+
"Input(Y@Grad) must exist.");
188+
189+
int axis = ctx->Attrs().Get<int>("axis");
190+
auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y"));
191+
int rank = dy_dim.size();
192+
PADDLE_ENFORCE(axis >= -rank && axis < rank,
193+
"Attr(axis) must be inside [-rank, rank), where rank = %d",
194+
rank);
195+
if (axis < 0) axis += rank;
196+
197+
PADDLE_ENFORCE_EQ(ctx->Outputs(framework::GradVarName("X")).size(),
198+
static_cast<size_t>(dy_dim[axis]),
199+
"Number of Outputs(X@Grad) is wrong");
200+
auto vec = framework::vectorize2int(dy_dim);
201+
vec.erase(vec.begin() + axis);
202+
ctx->SetOutputsDim(
203+
framework::GradVarName("X"),
204+
std::vector<framework::DDim>(dy_dim[axis], framework::make_ddim(vec)));
205+
}
206+
};
207+
208+
class StackGradOpDescMaker : public framework::SingleGradOpDescMaker {
209+
public:
210+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
211+
212+
protected:
213+
std::unique_ptr<framework::OpDesc> Apply() const override {
214+
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
215+
op->SetType("stack_grad");
216+
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
217+
op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
218+
op->SetAttrMap(Attrs());
219+
return op;
220+
}
221+
};
222+
223+
template <typename DeviceContext, typename T>
224+
class StackGradKernel : public framework::OpKernel<T> {
225+
using Tensor = framework::LoDTensor;
226+
227+
public:
228+
void Compute(const framework::ExecutionContext &ctx) const override {
229+
auto *dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
230+
auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
231+
int axis = ctx.Attr<int>("axis");
232+
if (axis < 0) axis += dy->dims().size();
233+
234+
int n = dy->dims()[axis];
235+
std::vector<T *> dx_datas(n); // NOLINT
236+
for (int i = 0; i < n; i++) {
237+
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
238+
}
239+
auto dy_data = dy->data<T>();
240+
241+
int pre = 1;
242+
for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
243+
int total_num = dy->numel();
244+
int post = total_num / (n * pre);
245+
246+
auto &dev_ctx = ctx.template device_context<DeviceContext>();
247+
constexpr auto kMaxThreshold = 16;
248+
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
249+
n > kMaxThreshold) {
250+
#ifdef __NVCC__
251+
VLOG(10) << "Stack more than " << kMaxThreshold
252+
<< " tensors on GPU may be slow.";
253+
thrust::device_vector<T *> device_dx_vec(dx_datas);
254+
auto dx_data_arr = device_dx_vec.data().get();
255+
#else
256+
auto dx_data_arr = dx_datas.data();
257+
#endif
258+
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n,
259+
post);
260+
#ifdef __NVCC__
261+
// Wait() must be called because device_dx_vec may be destructed before
262+
// kernel ends
263+
dev_ctx.Wait();
264+
#endif
265+
}
266+
#ifdef __NVCC__
267+
else { // NOLINT
268+
framework::Array<T *, kMaxThreshold> dx_data_arr;
269+
for (int i = 0; i < n; ++i) dx_data_arr[i] = dx_datas[i];
270+
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n,
271+
post);
272+
}
273+
#endif
274+
}
275+
};
276+
277+
} // namespace operators
278+
} // namespace paddle

0 commit comments

Comments
 (0)