/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

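// Grid-stride reduction kernel: each thread tracks the largest absolute
// value it sees in shared memory, then the block combines the per-thread
// maxima with a tree reduction and writes one partial maximum per block to
// out[blockIdx.x]. blockDim.x is assumed to be a power of two.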
template <typename T>
__global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  int tid = threadIdx.x;

  extern __shared__ T shared_max_data[];
  if (gridDim.x > 1) {
    shared_max_data[tid] = T(0);
    for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
      T tmp = fabs(in[i]);
      if (tmp > shared_max_data[tid]) {
        shared_max_data[tid] = tmp;
      }
    }
  } else {
    if (bid < n) {
      shared_max_data[tid] = fabs(in[bid]);
    } else {
      shared_max_data[tid] = T(0);
    }
  }
  __syncthreads();

  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i && shared_max_data[tid] < shared_max_data[tid + i]) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

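// Host-side helper: finds the absolute maximum of a device float array with
// two launches of FindAbsMaxKernel. The first launch writes one partial
// maximum per block into a temporary tensor, the second launch reduces those
// partials in a single block, and the result is copied back to the host.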
float FindAbsMaxGpu(const platform::CUDADeviceContext& ctx, const float* array,
                    int length) {
  float host_max;
  const int kNumThreads = 1024;
  int gridDimx = (length + kNumThreads - 1) / kNumThreads;
  gridDimx = (gridDimx > kNumThreads) ? kNumThreads : gridDimx;
  framework::Tensor t;
  float* device_max = t.mutable_data<float>(framework::make_ddim({gridDimx}),
                                            platform::CUDAPlace());
  FindAbsMaxKernel<float><<<gridDimx, kNumThreads, kNumThreads * sizeof(float),
                            ctx.stream()>>>(length, array, device_max);
  FindAbsMaxKernel<
      float><<<1, kNumThreads, kNumThreads * sizeof(float), ctx.stream()>>>(
      gridDimx, device_max, device_max);
  PADDLE_ENFORCE_EQ(
      cudaMemcpy(&host_max, device_max, sizeof(float), cudaMemcpyDeviceToHost),
      cudaSuccess, "cudaMemcpy failed");
  return host_max;
}

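// Element-wise saturation kernel: clamps in[i] to [min, max] and writes the
// result to out[i]. Each block also counts how many of its elements were
// clipped and stores that count in num_saturate[blockIdx.x] via a
// shared-memory sum reduction.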
template <typename T>
__global__ void ApplySaturateKernel(const int n, const T* in, T* out,
                                    int* num_saturate, const T min,
                                    const T max) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  int tid = threadIdx.x;

  extern __shared__ int shared_count[];
  shared_count[tid] = 0;
  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    if (in[i] > max) {
      out[i] = max;
      shared_count[tid] += 1;
    } else if (in[i] < min) {
      out[i] = min;
      shared_count[tid] += 1;
    } else {
      out[i] = in[i];
    }
  }
  __syncthreads();

  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i) {
      shared_count[tid] += shared_count[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    num_saturate[blockIdx.x] = shared_count[0];
  }
}

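// Single-block sum reduction: adds up the first n elements of `in` (here the
// per-block saturate counts) and writes the total to out[0].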
template <typename T>
__global__ void ReduceKernel(const int n, const T* in, T* out) {
  int tid = threadIdx.x;
  extern __shared__ T shared_sum[];
  if (tid < n) {
    shared_sum[tid] = in[tid];
  } else {
    shared_sum[tid] = T(0);
  }
  __syncthreads();
  // blockDim.x must be >= n; slots beyond n hold zero, so reducing over the
  // whole block is safe even when n is not a power of two.
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i) {
      shared_sum[tid] += shared_sum[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[0] = shared_sum[0];
  }
}

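// Host-side helper: saturates `in` into `out` on the GPU and returns the
// number of clipped elements. Per-block counts from ApplySaturateKernel are
// summed with ReduceKernel and the total is copied back to the host.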
template <typename T>
int ApplySaturateGpu(const platform::CUDADeviceContext& ctx, const int n,
                     const T* in, T* out, const T min, const T max) {
  int host_num_saturate;
  const int kNumThreads = 1024;
  int gridDimx = (n + kNumThreads - 1) / kNumThreads;
  gridDimx = (gridDimx > kNumThreads) ? kNumThreads : gridDimx;
  framework::Tensor t;
  int* device_num_saturate = t.mutable_data<int>(
      framework::make_ddim({gridDimx}), platform::CUDAPlace());
  ApplySaturateKernel<
      T><<<gridDimx, kNumThreads, kNumThreads * sizeof(int), ctx.stream()>>>(
      n, in, out, device_num_saturate, min, max);
  ReduceKernel<int><<<1, kNumThreads, kNumThreads * sizeof(int),
                      ctx.stream()>>>(gridDimx, device_num_saturate,
                                      device_num_saturate);
  PADDLE_ENFORCE_EQ(cudaSuccess,
                    cudaMemcpy(&host_num_saturate, device_num_saturate,
                               sizeof(int), cudaMemcpyDeviceToHost),
                    "cudaMemcpy failed");
  return host_num_saturate;
}

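// CUDA implementation of the fake_quantize operator. The quantization scale
// is chosen according to the "quantize_type" attribute: the current absolute
// maximum of the input ("abs_max"), the maximum over a window of recent
// scales ("range_abs_max"), or a moving average of the scale
// ("moving_average_abs_max"). The input is then saturated to [-scale, scale]
// and rounded onto bin_cnt integer levels.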
template <typename DeviceContext, typename T>
class FakeQuantizeCUDAKernel : public framework::OpKernel<T> {
 public:
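  // Replaces the window slot at current_iter with cur_scale and returns the
  // maximum scale in the window. When the evicted value was the previous
  // maximum, the window is rescanned with FindAbsMaxGpu.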
  T FindRangeAbsMax(const platform::CUDADeviceContext& ctx,
                    framework::Tensor* scale_list, framework::Tensor* out_scale,
                    const T& cur_scale, int window_size,
                    int current_iter) const {
    T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
    T remove_tmp = sl[current_iter];
    sl[current_iter] = cur_scale;
    T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
    if (max_scale < cur_scale) {
      max_scale = cur_scale;
    } else if (fabs(remove_tmp - max_scale) < 1e-6) {
      int size = (current_iter > window_size) ? window_size : current_iter;
      max_scale = T(FindAbsMaxGpu(ctx, scale_list->data<float>(), size));
    }
    return max_scale;
  }

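  // Updates the moving-average scale as 0.9 * cur_scale + 0.1 * previous
  // scale and returns the new value.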
  T FindMovingAverageAbsMax(framework::Tensor* in_scale,
                            framework::Tensor* out_scale,
                            const T& cur_scale) const {
    T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
    T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
    outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
    return T(outs[0]);
  }

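  // Computes the scale for the configured quantize_type, clips the input to
  // [-scale, scale] with ApplySaturateGpu, and rounds the result to integer
  // bins via Eigen on the device.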
  void Compute(const framework::ExecutionContext& context) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto& device_ctx = context.cuda_device_context();
    auto* tensor = context.Output<framework::Tensor>("Out");
    auto* in = context.Input<framework::Tensor>("X");
    const bool is_test = context.Attr<bool>("is_test");
    tensor->mutable_data<T>(in->place());
    context.Output<framework::Tensor>("OutMovingScale")
        ->mutable_data<T>(
            context.Input<framework::Tensor>("InMovingScale")->place());
    auto quantize_type =
        static_cast<std::string>(context.Attr<std::string>("quantize_type"));
    if (quantize_type == std::string("range_abs_max")) {
      context.Output<framework::Tensor>("OutScales")
          ->mutable_data<T>(
              context.Input<framework::Tensor>("InScales")->place());
      context.Output<framework::Tensor>("OutCurrentIter")
          ->mutable_data<T>(
              context.Input<framework::Tensor>("InCurrentIter")->place());
    }

    T scale = T(1);
    int window_size = context.Attr<int>("window_size");
    T bin_cnt = (T)((1 << (context.Attr<int>("bit_length") - 1)) - 1);
    if (quantize_type == std::string("abs_max")) {
      auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
      scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
      saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;

      auto& device_ctx = context.template device_context<DeviceContext>();
      auto* scale_list = context.Output<framework::Tensor>("OutScales");
      math::SetConstant<DeviceContext, T> scalar;
      scale_list->mutable_data<T>(context.GetPlace());
      scalar(device_ctx, scale_list, static_cast<T>(0));
      auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
      iter->mutable_data<T>(context.GetPlace());
      scalar(device_ctx, iter, static_cast<T>(0));
    } else if (quantize_type == std::string("range_abs_max")) {
      auto* moving_scale = const_cast<framework::Tensor*>(
          context.Input<framework::Tensor>("InMovingScale"));
      if (is_test) {
        scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
      } else {
        auto* it = const_cast<framework::Tensor*>(
            context.Input<framework::Tensor>("InCurrentIter"));
        auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
        int* last_iter = it->mutable_data<int>(platform::CPUPlace());
        int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
        auto* scale_list = context.Output<framework::Tensor>("OutScales");
        auto* saving_scale =
            context.Output<framework::Tensor>("OutMovingScale");
        scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
        scale = FindRangeAbsMax(device_ctx, scale_list, saving_scale, scale,
                                window_size, current_iter[0]);
        (*current_iter) = (*last_iter) + 1;
      }
    } else if (quantize_type == std::string("moving_average_abs_max")) {
      auto* moving_scale = const_cast<framework::Tensor*>(
          context.Input<framework::Tensor>("InMovingScale"));
      if (is_test) {
        scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
      } else {
        scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
        auto* saving_scale =
            context.Output<framework::Tensor>("OutMovingScale");
        scale = FindMovingAverageAbsMax(
            const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
      }
    }

    ApplySaturateGpu<T>(device_ctx, in->numel(), in->data<T>(),
                        tensor->mutable_data<T>(in->place()), -scale, scale);
    scale = bin_cnt / scale;

    // The saturated values are already in `tensor`, so the scaling and
    // rounding are applied to the output in place.
    auto& dev =
        *context.template device_context<DeviceContext>().eigen_device();
    auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
    auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
    eigen_out.device(dev) = (scale * eigen_in).round();
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_CUDA_KERNEL(fake_quantize,
                        paddle::operators::FakeQuantizeCUDAKernel<
                            paddle::platform::CUDADeviceContext, float>,
                        paddle::operators::FakeQuantizeCUDAKernel<
                            paddle::platform::CUDADeviceContext, double>);