|
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "paddle/fluid/operators/math/math_function.h" |
| 16 | +#include "paddle/fluid/operators/mean_iou_op.h" |
| 17 | +#include "paddle/fluid/platform/cuda_primitives.h" |
| 18 | +#include "paddle/fluid/platform/gpu_info.h" |
| 19 | + |
| 20 | +namespace paddle { |
| 21 | +namespace operators { |
| 22 | + |
| 23 | +using platform::PADDLE_CUDA_NUM_THREADS; |
| 24 | + |
| 25 | +#define CUDA_1D_KERNEL_LOOP(i, n) \ |
| 26 | + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ |
| 27 | + i += blockDim.x * gridDim.x) |
| 28 | + |
| 29 | +template <typename T> |
| 30 | +__global__ void CountCUDAKernel(const int num_classes, const int count, |
| 31 | + const T* predictions, const T* labels, |
| 32 | + int* wrong, int* correct) { |
| 33 | + extern __shared__ int blcok_cache[]; |
| 34 | + int* wrong_c = blcok_cache; |
| 35 | + int* correct_c = blcok_cache + num_classes; |
| 36 | + // init cache |
| 37 | + for (int i = threadIdx.x; i < num_classes * 2; i += blockDim.x) { |
| 38 | + blcok_cache[i] = 0; |
| 39 | + } |
| 40 | + __syncthreads(); |
| 41 | + |
| 42 | + T pred; |
| 43 | + T label; |
| 44 | + CUDA_1D_KERNEL_LOOP(i, count) { |
| 45 | + pred = predictions[i]; |
| 46 | + label = labels[i]; |
| 47 | + if (pred == label) { |
| 48 | + atomicAdd(correct_c + pred, 1); |
| 49 | + } else { |
| 50 | + atomicAdd(wrong_c + pred, 1); |
| 51 | + atomicAdd(wrong_c + label, 1); |
| 52 | + } |
| 53 | + } |
| 54 | + |
| 55 | + __syncthreads(); |
| 56 | + |
| 57 | + for (int i = threadIdx.x; i < num_classes; i += blockDim.x) { |
| 58 | + atomicAdd(wrong + i, wrong_c[i]); |
| 59 | + atomicAdd(correct + i, correct_c[i]); |
| 60 | + } |
| 61 | +} |
| 62 | + |
| 63 | +__global__ void ComputeIoUCUDAKernel(const int num_classes, int* wrong, |
| 64 | + int* correct, float* ious, float* iou) { |
| 65 | + __shared__ int valid_count_c; |
| 66 | + if (threadIdx.x == 0) { |
| 67 | + valid_count_c = 0; |
| 68 | + } |
| 69 | + __syncthreads(); |
| 70 | + CUDA_1D_KERNEL_LOOP(i, num_classes) { |
| 71 | + int wrong_n = wrong[i]; |
| 72 | + int correct_n = correct[i]; |
| 73 | + int denominator = wrong_n + correct_n; |
| 74 | + if (denominator > 0) { |
| 75 | + atomicAdd(&valid_count_c, 1); |
| 76 | + ious[i] = static_cast<float>(correct_n) / denominator; |
| 77 | + } else { |
| 78 | + ious[i] = 0; |
| 79 | + } |
| 80 | + } |
| 81 | + __syncthreads(); |
| 82 | + if (threadIdx.x == 0) { |
| 83 | + float iou_sum = 0; |
| 84 | + for (int i = 0; i < num_classes; ++i) { |
| 85 | + iou_sum += ious[i]; |
| 86 | + } |
| 87 | + iou[0] += iou_sum / valid_count_c; |
| 88 | + } |
| 89 | +} |
| 90 | + |
| 91 | +template <typename T> |
| 92 | +class MeanIoUCUDAOpKernel : public framework::OpKernel<T> { |
| 93 | + public: |
| 94 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 95 | + auto& place = *ctx.template device_context<platform::CUDADeviceContext>() |
| 96 | + .eigen_device(); |
| 97 | + // get input and output tensor |
| 98 | + auto* predictions = ctx.Input<Tensor>("Predictions"); |
| 99 | + auto* labels = ctx.Input<Tensor>("Labels"); |
| 100 | + auto* out_mean_iou = ctx.Output<Tensor>("OutMeanIou"); |
| 101 | + auto* out_wrong = ctx.Output<Tensor>("OutWrong"); |
| 102 | + auto* out_correct = ctx.Output<Tensor>("OutCorrect"); |
| 103 | + int num_classes = static_cast<int>(ctx.Attr<int>("num_classes")); |
| 104 | + |
| 105 | + // Get data ptr |
| 106 | + const T* predictions_data = predictions->data<T>(); |
| 107 | + const T* labels_data = labels->data<T>(); |
| 108 | + int* out_wrong_data = out_wrong->mutable_data<int>(ctx.GetPlace()); |
| 109 | + int* out_correct_data = out_correct->mutable_data<int>(ctx.GetPlace()); |
| 110 | + float* out_mean_iou_data = |
| 111 | + out_mean_iou->mutable_data<float>(ctx.GetPlace()); |
| 112 | + |
| 113 | + // Get Eigen tensor |
| 114 | + auto out_mean_iou_t = EigenTensor<float, 1>::From(*out_mean_iou); |
| 115 | + auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong); |
| 116 | + auto out_correct_t = EigenTensor<int, 1>::From(*out_correct); |
| 117 | + |
| 118 | + // Temporary tensor |
| 119 | + Tensor ious; |
| 120 | + float* ious_data = ious.mutable_data<float>( |
| 121 | + {static_cast<int64_t>(num_classes)}, ctx.GetPlace()); |
| 122 | + auto ious_t = EigenTensor<float, 1>::From(ious); |
| 123 | + |
| 124 | + // Init out_wrong, out_correct and out_mean_iou |
| 125 | + out_wrong_t.device(place) = out_wrong_t.constant(0); |
| 126 | + out_correct_t.device(place) = out_correct_t.constant(0); |
| 127 | + out_mean_iou_t.device(place) = out_mean_iou_t.constant(0.0f); |
| 128 | + |
| 129 | + // collect pre wrong, correct and mean_iou |
| 130 | + auto in_mean_ious = ctx.MultiInput<Tensor>("InMeanIou"); |
| 131 | + for (int i = 0; i < in_mean_ious.size(); ++i) { |
| 132 | + out_mean_iou_t.device(place) += |
| 133 | + EigenTensor<float, 1>::From(*in_mean_ious[i]); |
| 134 | + } |
| 135 | + auto in_wrongs = ctx.MultiInput<Tensor>("InWrongs"); |
| 136 | + for (int i = 0; i < in_wrongs.size(); ++i) { |
| 137 | + out_wrong_t.device(place) += EigenTensor<int, 1>::From(*in_wrongs[i]); |
| 138 | + } |
| 139 | + auto in_corrects = ctx.MultiInput<Tensor>("InCorrects"); |
| 140 | + for (int i = 0; i < in_corrects.size(); ++i) { |
| 141 | + out_correct_t.device(place) += EigenTensor<int, 1>::From(*in_corrects[i]); |
| 142 | + } |
| 143 | + // compute |
| 144 | + auto stream = ctx.cuda_device_context().stream(); |
| 145 | + int block = PADDLE_CUDA_NUM_THREADS; |
| 146 | + int grid = (predictions->numel() + block - 1) / block; |
| 147 | + int cache_size = (num_classes * 2 + 1) * sizeof(int); |
| 148 | + CountCUDAKernel<T><<<grid, block, cache_size, stream>>>( |
| 149 | + num_classes, predictions->numel(), predictions_data, labels_data, |
| 150 | + out_wrong_data, out_correct_data); |
| 151 | + ctx.device_context().Wait(); |
| 152 | + ComputeIoUCUDAKernel<<<1, block, 0, stream>>>(num_classes, out_wrong_data, |
| 153 | + out_correct_data, ious_data, |
| 154 | + out_mean_iou_data); |
| 155 | + } |
| 156 | +}; |
| 157 | + |
| 158 | +} // namespace operators |
| 159 | +} // namespace paddle |
| 160 | + |
| 161 | +namespace ops = paddle::operators; |
| 162 | +REGISTER_OP_CUDA_KERNEL(mean_iou, ops::MeanIoUCUDAOpKernel<int>, |
| 163 | + ops::MeanIoUCUDAOpKernel<int64_t>, |
| 164 | + ops::MeanIoUCUDAOpKernel<int32_t>); |
0 commit comments