Commit 63ac947

Merge pull request #16135 from heavengate/shift

Add temporal_shift op for TSM model

2 parents: bb80dae + 193185b

7 files changed: +585 −0 lines

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
@@ -225,6 +225,7 @@ paddle.fluid.layers.merge_selected_rows (ArgSpec(args=['x', 'name'], varargs=Non
 paddle.fluid.layers.get_tensor_from_selected_rows (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7ffc849e71f31dfe29030ff94e662de6'))
 paddle.fluid.layers.lstm (ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)), ('document', 'd5e6c494ac35100e2ed4d4bd9a1ed932'))
 paddle.fluid.layers.shuffle_channel (ArgSpec(args=['x', 'group', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '2fa6782d43d02ae64482d21235a82949'))
+paddle.fluid.layers.temporal_shift (ArgSpec(args=['x', 'seg_num', 'shift_ratio', 'name'], varargs=None, keywords=None, defaults=(0.25, None)), ('document', 'fe4481fb31363b09cfdd228fc6776ddf'))
 paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)), ('document', '8404e472ac12b4a30a505d3d3a3e5fdb'))
 paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '1546136806fef5c08f6918544bd9151d'))
 paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', '2f6ff96864054a31aa4bb659c6722c99'))
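The new ArgSpec above corresponds to the following Python usage. This is an illustrative sketch: the layer name, arguments, and defaults come from the spec; the fluid.layers.data call and its arguments are assumed boilerplate, not part of this diff.

import paddle.fluid as fluid

# Input layout is [N*T, C, H, W]; dims[0] must be divisible by seg_num.
x = fluid.layers.data(name='frames', shape=[8, 4, 32, 32], dtype='float32',
                      append_batch_size=False)
out = fluid.layers.temporal_shift(x, seg_num=2, shift_ratio=0.25)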
paddle/fluid/operators/temporal_shift_op.cc

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/temporal_shift_op.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class TemporalShiftOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of TemporalShiftOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of TemporalShiftOp should not be null.");

    auto dim_x = ctx->GetInputDim("X");
    PADDLE_ENFORCE_EQ(dim_x.size(), 4,
                      "Input(X) rank should be 4 in shape of [N*T, C, H, W].");

    int seg_num = ctx->Attrs().Get<int>("seg_num");
    float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
    PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0.");
    // Both bounds must hold, so the conditions are joined with &&
    // (the original `||` made this check always pass).
    PADDLE_ENFORCE(shift_ratio > 0 && shift_ratio < 0.5,
                   "Attr(shift_ratio) should be greater than 0 and less "
                   "than 0.5.");

    if (ctx->IsRuntime()) {
      PADDLE_ENFORCE_EQ(
          dim_x[0] % seg_num, 0,
          "Input(X) dims[0] should be divided exactly by Attr(seg_num).");
    }

    ctx->SetOutputDim("Out", dim_x);
    ctx->ShareLoD("X", "Out");
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.GetPlace());
  }
};

class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input tensor of temporal shift operator. "
             "This is a 4-D tensor with shape of [N*T, C, H, W], "
             "where N is the batch size, T is the temporal segment "
             "number, C is the channel number, H is the height of "
             "features and W is the width of features.");
    AddOutput("Out",
              "The output tensor of temporal shift operator. "
              "This is a 4-D tensor in the same shape as Input(X).");

    AddAttr<int>("seg_num",
                 "The temporal segment number; this should be a positive "
                 "integer.");
    AddAttr<float>(
        "shift_ratio",
        "The shift ratio of the channels: the first :attr:`shift_ratio` part "
        "of the channels will be shifted by -1 along the temporal dimension, "
        "and the second :attr:`shift_ratio` part of the channels will be "
        "shifted by 1 along the temporal dimension. Default 0.25.")
        .SetDefault(0.25);

    AddComment(R"DOC(
This operator calculates the temporal shifting features for Input(X).

Input(X) should be in shape of [N*T, C, H, W], where N is the batch
size, T is the temporal segment number specified by :attr:`seg_num`,
C is the channel number, and H and W are the height and width of
features.

Temporal shifting is calculated as follows:

Step 1: Reshape Input(X) to [N, T, C, H, W].

Step 2: Pad the reshaped result with zeros along the 2nd (T) dimension,
with a padding width of 1 on each side; the padded result is in shape
of [N, T+2, C, H, W].

Step 3: Assume :attr:`shift_ratio` is :math:`1/4`, and slice the padded
result as follows:

$$
slice1 = x[:, :T, :C/4, :, :]
$$
$$
slice2 = x[:, 2:T+2, C/4:C/2, :, :]
$$
$$
slice3 = x[:, 1:T+1, C/2:, :, :]
$$

Step 4: Concatenate the three slices along the 3rd (C) dimension and
reshape the result to [N*T, C, H, W].

For details of temporal shifting, please refer to the paper:
`Temporal Shift Module <http://arxiv.org/abs/1811.08383>`_ .

)DOC");
  }
};

class TemporalShiftOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.GetPlace());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp,
                  ops::TemporalShiftOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
                       ops::TemporalShiftKernel<double>);
REGISTER_OP_CPU_KERNEL(temporal_shift_grad, ops::TemporalShiftGradKernel<float>,
                       ops::TemporalShiftGradKernel<double>);
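The four DOC steps above can be sanity-checked with a short NumPy sketch. This is a hedged reference implementation of the described reshape/pad/slice/concat procedure, not the operator's actual CPU kernel (which lives in temporal_shift_op.h and is not part of this diff); all names in it are illustrative.

import numpy as np

def temporal_shift_ref(x, seg_num, shift_ratio=0.25):
    nt, c, h, w = x.shape
    n, t = nt // seg_num, seg_num
    c1 = int(c * shift_ratio)
    c2 = int(c * 2 * shift_ratio)
    # Step 1: recover the temporal axis: [N*T, C, H, W] -> [N, T, C, H, W].
    y = x.reshape(n, t, c, h, w)
    # Step 2: zero-pad one frame on each side of the T axis -> [N, T+2, C, H, W].
    y = np.pad(y, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)), mode='constant')
    # Step 3: slice the shifted channel groups.
    s1 = y[:, 0:t, :c1]        # first group: shifted by -1 in time
    s2 = y[:, 2:t + 2, c1:c2]  # second group: shifted by +1 in time
    s3 = y[:, 1:t + 1, c2:]    # remaining channels: unshifted
    # Step 4: concatenate along C and restore [N*T, C, H, W].
    return np.concatenate([s1, s2, s3], axis=2).reshape(nt, c, h, w)

out = temporal_shift_ref(np.random.rand(8, 4, 2, 2).astype('float32'), seg_num=2)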
paddle/fluid/operators/temporal_shift_op.cu

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/temporal_shift_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using framework::Tensor;

template <typename T>
__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
                                  const int tchw, const int chw, const int hw,
                                  const int w, const int t, const int c,
                                  const float shift_ratio) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  int src_it = 0;
  for (; tid < ntchw; tid += stride) {
    // Decompose the flat thread id into (n, t, c, h, w) coordinates.
    int in = tid / tchw;
    int it = (tid % tchw) / chw;
    int ic = (tid % chw) / hw;
    int ih = (tid % hw) / w;
    int iw = tid % w;

    // Channel group boundaries; static_cast<int> (not <T>) since c1/c2
    // are integer channel indices.
    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    if (ic < c1) {
      src_it = it - 1;
    } else if (ic < c2) {
      src_it = it + 1;
    } else {
      src_it = it;
    }

    // Out-of-range source frames correspond to the zero padding.
    if (src_it < 0 || src_it >= t) {
      output[tid] = 0;
    } else {
      int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
      output[tid] = input[src_idx];
    }
  }
}

template <typename T>
__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
                                  const int ntchw, const int tchw,
                                  const int chw, const int hw, const int w,
                                  const int t, const int c,
                                  const float shift_ratio) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  int src_it = 0;
  for (; tid < ntchw; tid += stride) {
    int in = tid / tchw;
    int it = (tid % tchw) / chw;
    int ic = (tid % chw) / hw;
    int ih = (tid % hw) / w;
    int iw = tid % w;

    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    if (ic < c1) {
      src_it = it - 1;
    } else if (ic < c2) {
      src_it = it + 1;
    } else {
      src_it = it;
    }

    // Scatter the gradient back to the source position of the shift.
    if (src_it >= 0 && src_it < t) {
      int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
      input_grad[src_idx] = output_grad[tid];
    }
  }
}

template <typename T>
class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");

    const int nt = input->dims()[0];
    const int c = input->dims()[1];
    const int h = input->dims()[2];
    const int w = input->dims()[3];

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());

    int pixelNum = nt * chw;
    int grid_dim = (pixelNum + 512 - 1) / 512;
    grid_dim = grid_dim > 8 ? 8 : grid_dim;

    KeTemporalShiftFw<
        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
        input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
  }
};

template <typename T>
class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");

    const int nt = output_grad->dims()[0];
    const int c = output_grad->dims()[1];
    const int h = output_grad->dims()[2];
    const int w = output_grad->dims()[3];

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data =
        input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
    // Zero-initialize, since positions that map to the zero padding
    // receive no gradient write in the backward kernel.
    math::SetConstant<platform::CUDADeviceContext, T>()(
        ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
        static_cast<T>(0));

    int pixelNum = nt * chw;
    int grid_dim = (pixelNum + 512 - 1) / 512;
    grid_dim = grid_dim > 8 ? 8 : grid_dim;

    KeTemporalShiftBw<
        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
        output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
        shift_ratio);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(temporal_shift, ops::TemporalShiftOpCUDAKernel<float>,
                        ops::TemporalShiftOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(temporal_shift_grad,
                        ops::TemporalShiftGradOpCUDAKernel<float>,
                        ops::TemporalShiftGradOpCUDAKernel<double>);
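Both kernels call GetEntryIndex, which is declared in temporal_shift_op.h and not shown in this diff. Judging from its call sites, it presumably flattens (in, it, ic, ih, iw) coordinates back into a linear offset, i.e. the inverse of the div/mod decomposition of tid above. A sketch of that assumed mapping:

def get_entry_index(in_, it, ic, ih, iw, tchw, chw, hw, w):
    # Inverse of: in_ = tid // tchw; it = (tid % tchw) // chw;
    #             ic = (tid % chw) // hw; ih = (tid % hw) // w; iw = tid % w
    return in_ * tchw + it * chw + ic * hw + ih * w + iw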
