|
| 1 | +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. |
| 2 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | + you may not use this file except in compliance with the License. |
| 4 | + You may obtain a copy of the License at |
| 5 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | + Unless required by applicable law or agreed to in writing, software |
| 7 | + distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | + See the License for the specific language governing permissions and |
| 10 | + limitations under the License. */ |
| 11 | + |
| 12 | +#include "paddle/fluid/operators/temporal_shift_op.h" |
| 13 | +#include "paddle/fluid/platform/cuda_primitives.h" |
| 14 | + |
| 15 | +namespace paddle { |
| 16 | +namespace operators { |
| 17 | + |
| 18 | +using framework::Tensor; |
| 19 | + |
| 20 | +template <typename T> |
| 21 | +__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw, |
| 22 | + const int tchw, const int chw, const int hw, |
| 23 | + const int w, const int t, const int c, |
| 24 | + const float shift_ratio) { |
| 25 | + int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 26 | + int stride = blockDim.x * gridDim.x; |
| 27 | + int src_it = 0; |
| 28 | + for (; tid < ntchw; tid += stride) { |
| 29 | + int in = tid / tchw; |
| 30 | + int it = (tid % tchw) / chw; |
| 31 | + int ic = (tid % chw) / hw; |
| 32 | + int ih = (tid % hw) / w; |
| 33 | + int iw = tid % w; |
| 34 | + |
| 35 | + const int c1 = static_cast<T>(c * shift_ratio); |
| 36 | + const int c2 = static_cast<T>(c * 2 * shift_ratio); |
| 37 | + |
| 38 | + if (ic < c1) { |
| 39 | + src_it = it - 1; |
| 40 | + } else if (ic < c2) { |
| 41 | + src_it = it + 1; |
| 42 | + } else { |
| 43 | + src_it = it; |
| 44 | + } |
| 45 | + |
| 46 | + if (src_it < 0 || src_it >= t) { |
| 47 | + output[tid] = 0; |
| 48 | + } else { |
| 49 | + int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w); |
| 50 | + output[tid] = input[src_idx]; |
| 51 | + } |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | +template <typename T> |
| 56 | +__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, |
| 57 | + const int ntchw, const int tchw, |
| 58 | + const int chw, const int hw, const int w, |
| 59 | + const int t, const int c, |
| 60 | + const float shift_ratio) { |
| 61 | + int tid = blockIdx.x * blockDim.x + threadIdx.x; |
| 62 | + int stride = blockDim.x * gridDim.x; |
| 63 | + int src_it = 0; |
| 64 | + for (; tid < ntchw; tid += stride) { |
| 65 | + int in = tid / tchw; |
| 66 | + int it = (tid % tchw) / chw; |
| 67 | + int ic = (tid % chw) / hw; |
| 68 | + int ih = (tid % hw) / w; |
| 69 | + int iw = tid % w; |
| 70 | + |
| 71 | + const int c1 = static_cast<T>(c * shift_ratio); |
| 72 | + const int c2 = static_cast<T>(c * 2 * shift_ratio); |
| 73 | + |
| 74 | + if (ic < c1) { |
| 75 | + src_it = it - 1; |
| 76 | + } else if (ic < c2) { |
| 77 | + src_it = it + 1; |
| 78 | + } else { |
| 79 | + src_it = it; |
| 80 | + } |
| 81 | + |
| 82 | + if (src_it >= 0 && src_it < t) { |
| 83 | + int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w); |
| 84 | + input_grad[src_idx] = output_grad[tid]; |
| 85 | + } |
| 86 | + } |
| 87 | +} |
| 88 | + |
| 89 | +template <typename T> |
| 90 | +class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> { |
| 91 | + public: |
| 92 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 93 | + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), |
| 94 | + "This kernel only runs on GPU device."); |
| 95 | + auto* input = ctx.Input<Tensor>("X"); |
| 96 | + auto* output = ctx.Output<Tensor>("Out"); |
| 97 | + int t = ctx.Attr<int>("seg_num"); |
| 98 | + float shift_ratio = ctx.Attr<float>("shift_ratio"); |
| 99 | + |
| 100 | + const int nt = input->dims()[0]; |
| 101 | + const int c = input->dims()[1]; |
| 102 | + const int h = input->dims()[2]; |
| 103 | + const int w = input->dims()[3]; |
| 104 | + |
| 105 | + const int hw = h * w; |
| 106 | + const int chw = c * hw; |
| 107 | + const int tchw = t * chw; |
| 108 | + const int ntchw = nt * chw; |
| 109 | + |
| 110 | + const T* input_data = input->data<T>(); |
| 111 | + T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace()); |
| 112 | + |
| 113 | + int pixelNum = nt * chw; |
| 114 | + int grid_dim = (pixelNum + 512 - 1) / 512; |
| 115 | + grid_dim = grid_dim > 8 ? 8 : grid_dim; |
| 116 | + |
| 117 | + KeTemporalShiftFw< |
| 118 | + T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>( |
| 119 | + input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio); |
| 120 | + } |
| 121 | +}; |
| 122 | + |
| 123 | +template <typename T> |
| 124 | +class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> { |
| 125 | + public: |
| 126 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 127 | + auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X")); |
| 128 | + auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); |
| 129 | + int t = ctx.Attr<int>("seg_num"); |
| 130 | + float shift_ratio = ctx.Attr<float>("shift_ratio"); |
| 131 | + |
| 132 | + const int nt = output_grad->dims()[0]; |
| 133 | + const int c = output_grad->dims()[1]; |
| 134 | + const int h = output_grad->dims()[2]; |
| 135 | + const int w = output_grad->dims()[3]; |
| 136 | + |
| 137 | + const int hw = h * w; |
| 138 | + const int chw = c * hw; |
| 139 | + const int tchw = t * chw; |
| 140 | + const int ntchw = nt * chw; |
| 141 | + |
| 142 | + const T* output_grad_data = output_grad->data<T>(); |
| 143 | + T* input_grad_data = |
| 144 | + input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace()); |
| 145 | + math::SetConstant<platform::CUDADeviceContext, T>()( |
| 146 | + ctx.template device_context<platform::CUDADeviceContext>(), input_grad, |
| 147 | + static_cast<T>(0)); |
| 148 | + |
| 149 | + int pixelNum = nt * chw; |
| 150 | + int grid_dim = (pixelNum + 512 - 1) / 512; |
| 151 | + grid_dim = grid_dim > 8 ? 8 : grid_dim; |
| 152 | + |
| 153 | + KeTemporalShiftBw< |
| 154 | + T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>( |
| 155 | + output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c, |
| 156 | + shift_ratio); |
| 157 | + } |
| 158 | +}; |
| 159 | + |
| 160 | +} // namespace operators |
| 161 | +} // namespace paddle |
| 162 | + |
| 163 | +namespace ops = paddle::operators; |
| 164 | +REGISTER_OP_CUDA_KERNEL(temporal_shift, ops::TemporalShiftOpCUDAKernel<float>, |
| 165 | + ops::TemporalShiftOpCUDAKernel<double>); |
| 166 | +REGISTER_OP_CUDA_KERNEL(temporal_shift_grad, |
| 167 | + ops::TemporalShiftGradOpCUDAKernel<float>, |
| 168 | + ops::TemporalShiftGradOpCUDAKernel<double>); |
0 commit comments