Skip to content

Commit 6fcdb24

Browse files
Add mean IOU op. (#10519)
* Add mean_iou op. * Add unitest for mean iou op. * Add optional collections of confusion matrix and mean_iou. * Fix cuda kernel. * Refine code. 1. Merge computing in GPU to two kernel. 2. Use wrong array and correct array instead of confusion matrix. * Add python api and fix cuda kernel. * Fix comments. * Small fix. * Small fix.
1 parent f790b96 commit 6fcdb24

File tree

5 files changed

+570
-62
lines changed

5 files changed

+570
-62
lines changed

paddle/fluid/operators/mean_iou_op.cc

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/mean_iou_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class MeanIoUOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext* ctx) const override {
25+
PADDLE_ENFORCE(ctx->HasInput("Predictions"),
26+
"Input (Predictions) of MeanIoU op should not be null.");
27+
PADDLE_ENFORCE(ctx->HasInput("Labels"),
28+
"Input (labels) of MeanIoU op should not be null.");
29+
PADDLE_ENFORCE(ctx->HasOutput("OutMeanIou"),
30+
"Output (OutMeanIou) of MeanIoU op should not be null.");
31+
PADDLE_ENFORCE(ctx->HasOutput("OutWrong"),
32+
"Output (OutWrong) of MeanIoU op should not be null.");
33+
PADDLE_ENFORCE(ctx->HasOutput("OutCorrect"),
34+
"Output (OutWrong) of MeanIoU op should not be null.");
35+
36+
int64_t num_classes =
37+
static_cast<int64_t>(ctx->Attrs().Get<int>("num_classes"));
38+
39+
ctx->SetOutputDim("OutMeanIou", {1});
40+
ctx->SetOutputDim("OutWrong", {num_classes});
41+
ctx->SetOutputDim("OutCorrect", {num_classes});
42+
}
43+
44+
protected:
45+
framework::OpKernelType GetExpectedKernelType(
46+
const framework::ExecutionContext& ctx) const override {
47+
return framework::OpKernelType(
48+
framework::ToDataType(ctx.Input<Tensor>("Predictions")->type()),
49+
ctx.GetPlace());
50+
}
51+
};
52+
53+
class MeanIoUOpMaker : public framework::OpProtoAndCheckerMaker {
54+
public:
55+
void Make() override {
56+
AddInput("Predictions",
57+
"(Tensor), A Tensor of prediction results for semantic labels"
58+
" with type int32 or int64. The rank should be greater than 1.");
59+
AddInput(
60+
"Labels",
61+
"(Tensor), A Tensor of ground truth labels with type int32 or int64."
62+
"Its shape should be the same as Input(Predictions).");
63+
AddInput("InWrongs",
64+
"(vector<Tensor>), A list of Tensor with shape "
65+
"[num_classes]. They are used to collect wrong number among "
66+
"batches. Empty list is also valid here.")
67+
.AsDuplicable()
68+
.AsDispensable();
69+
AddInput(
70+
"InCorrects",
71+
"(vector<Tensor>), A list of Tensor with shape "
72+
"[num_classes]. They are used to collect correct number among batches. "
73+
"Empty list is also valid here.")
74+
.AsDuplicable()
75+
.AsDispensable();
76+
AddInput("InMeanIou",
77+
"(vector<Tensor>), A list of Tensor that Output(mean_iou) should "
78+
"be added to. Empty list is also valid here.")
79+
.AsDuplicable()
80+
.AsDispensable();
81+
AddOutput("OutMeanIou",
82+
"(vector<Tensor>), A Tensor representing the"
83+
" mean intersection-over-union with shape [1].");
84+
AddOutput("OutWrong", "(Tensor), A Tensor with shape [num_classes]. ");
85+
AddOutput("OutCorrect", "(Tensor), A Tensor with shape [num_classes]. ");
86+
AddAttr<int>("num_classes", "(int), The possible number of labels.");
87+
88+
AddComment(R"DOC(
89+
mean-IOU Operator.
90+
Mean Intersection-Over-Union is a common evaluation metric for
91+
semantic image segmentation, which first computes the IOU for each
92+
semantic class and then computes the average over classes.
93+
IOU is defined as follows:
94+
IOU = true_positive / (true_positive + false_positive + false_negative).
95+
It is based on pixel level area while "IOU Similarity Operator"
96+
is based on area of rectangle.
97+
98+
)DOC");
99+
}
100+
};
101+
102+
} // namespace operators
103+
} // namespace paddle
104+
105+
namespace ops = paddle::operators;
106+
REGISTER_OPERATOR(mean_iou, ops::MeanIoUOp, ops::MeanIoUOpMaker,
107+
paddle::framework::EmptyGradOpMaker);
108+
REGISTER_OP_CPU_KERNEL(mean_iou, ops::MeanIoUKernel<int>,
109+
ops::MeanIoUKernel<int32_t>,
110+
ops::MeanIoUKernel<int64_t>);

paddle/fluid/operators/mean_iou_op.cu

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/math/math_function.h"
16+
#include "paddle/fluid/operators/mean_iou_op.h"
17+
#include "paddle/fluid/platform/cuda_primitives.h"
18+
#include "paddle/fluid/platform/gpu_info.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
using platform::PADDLE_CUDA_NUM_THREADS;
24+
25+
#define CUDA_1D_KERNEL_LOOP(i, n) \
26+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
27+
i += blockDim.x * gridDim.x)
28+
29+
template <typename T>
30+
__global__ void CountCUDAKernel(const int num_classes, const int count,
31+
const T* predictions, const T* labels,
32+
int* wrong, int* correct) {
33+
extern __shared__ int blcok_cache[];
34+
int* wrong_c = blcok_cache;
35+
int* correct_c = blcok_cache + num_classes;
36+
// init cache
37+
for (int i = threadIdx.x; i < num_classes * 2; i += blockDim.x) {
38+
blcok_cache[i] = 0;
39+
}
40+
__syncthreads();
41+
42+
T pred;
43+
T label;
44+
CUDA_1D_KERNEL_LOOP(i, count) {
45+
pred = predictions[i];
46+
label = labels[i];
47+
if (pred == label) {
48+
atomicAdd(correct_c + pred, 1);
49+
} else {
50+
atomicAdd(wrong_c + pred, 1);
51+
atomicAdd(wrong_c + label, 1);
52+
}
53+
}
54+
55+
__syncthreads();
56+
57+
for (int i = threadIdx.x; i < num_classes; i += blockDim.x) {
58+
atomicAdd(wrong + i, wrong_c[i]);
59+
atomicAdd(correct + i, correct_c[i]);
60+
}
61+
}
62+
63+
__global__ void ComputeIoUCUDAKernel(const int num_classes, int* wrong,
64+
int* correct, float* ious, float* iou) {
65+
__shared__ int valid_count_c;
66+
if (threadIdx.x == 0) {
67+
valid_count_c = 0;
68+
}
69+
__syncthreads();
70+
CUDA_1D_KERNEL_LOOP(i, num_classes) {
71+
int wrong_n = wrong[i];
72+
int correct_n = correct[i];
73+
int denominator = wrong_n + correct_n;
74+
if (denominator > 0) {
75+
atomicAdd(&valid_count_c, 1);
76+
ious[i] = static_cast<float>(correct_n) / denominator;
77+
} else {
78+
ious[i] = 0;
79+
}
80+
}
81+
__syncthreads();
82+
if (threadIdx.x == 0) {
83+
float iou_sum = 0;
84+
for (int i = 0; i < num_classes; ++i) {
85+
iou_sum += ious[i];
86+
}
87+
iou[0] += iou_sum / valid_count_c;
88+
}
89+
}
90+
91+
template <typename T>
92+
class MeanIoUCUDAOpKernel : public framework::OpKernel<T> {
93+
public:
94+
void Compute(const framework::ExecutionContext& ctx) const override {
95+
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
96+
.eigen_device();
97+
// get input and output tensor
98+
auto* predictions = ctx.Input<Tensor>("Predictions");
99+
auto* labels = ctx.Input<Tensor>("Labels");
100+
auto* out_mean_iou = ctx.Output<Tensor>("OutMeanIou");
101+
auto* out_wrong = ctx.Output<Tensor>("OutWrong");
102+
auto* out_correct = ctx.Output<Tensor>("OutCorrect");
103+
int num_classes = static_cast<int>(ctx.Attr<int>("num_classes"));
104+
105+
// Get data ptr
106+
const T* predictions_data = predictions->data<T>();
107+
const T* labels_data = labels->data<T>();
108+
int* out_wrong_data = out_wrong->mutable_data<int>(ctx.GetPlace());
109+
int* out_correct_data = out_correct->mutable_data<int>(ctx.GetPlace());
110+
float* out_mean_iou_data =
111+
out_mean_iou->mutable_data<float>(ctx.GetPlace());
112+
113+
// Get Eigen tensor
114+
auto out_mean_iou_t = EigenTensor<float, 1>::From(*out_mean_iou);
115+
auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong);
116+
auto out_correct_t = EigenTensor<int, 1>::From(*out_correct);
117+
118+
// Temporary tensor
119+
Tensor ious;
120+
float* ious_data = ious.mutable_data<float>(
121+
{static_cast<int64_t>(num_classes)}, ctx.GetPlace());
122+
auto ious_t = EigenTensor<float, 1>::From(ious);
123+
124+
// Init out_wrong, out_correct and out_mean_iou
125+
out_wrong_t.device(place) = out_wrong_t.constant(0);
126+
out_correct_t.device(place) = out_correct_t.constant(0);
127+
out_mean_iou_t.device(place) = out_mean_iou_t.constant(0.0f);
128+
129+
// collect pre wrong, correct and mean_iou
130+
auto in_mean_ious = ctx.MultiInput<Tensor>("InMeanIou");
131+
for (int i = 0; i < in_mean_ious.size(); ++i) {
132+
out_mean_iou_t.device(place) +=
133+
EigenTensor<float, 1>::From(*in_mean_ious[i]);
134+
}
135+
auto in_wrongs = ctx.MultiInput<Tensor>("InWrongs");
136+
for (int i = 0; i < in_wrongs.size(); ++i) {
137+
out_wrong_t.device(place) += EigenTensor<int, 1>::From(*in_wrongs[i]);
138+
}
139+
auto in_corrects = ctx.MultiInput<Tensor>("InCorrects");
140+
for (int i = 0; i < in_corrects.size(); ++i) {
141+
out_correct_t.device(place) += EigenTensor<int, 1>::From(*in_corrects[i]);
142+
}
143+
// compute
144+
auto stream = ctx.cuda_device_context().stream();
145+
int block = PADDLE_CUDA_NUM_THREADS;
146+
int grid = (predictions->numel() + block - 1) / block;
147+
int cache_size = (num_classes * 2 + 1) * sizeof(int);
148+
CountCUDAKernel<T><<<grid, block, cache_size, stream>>>(
149+
num_classes, predictions->numel(), predictions_data, labels_data,
150+
out_wrong_data, out_correct_data);
151+
ctx.device_context().Wait();
152+
ComputeIoUCUDAKernel<<<1, block, 0, stream>>>(num_classes, out_wrong_data,
153+
out_correct_data, ious_data,
154+
out_mean_iou_data);
155+
}
156+
};
157+
158+
} // namespace operators
159+
} // namespace paddle
160+
161+
namespace ops = paddle::operators;
162+
REGISTER_OP_CUDA_KERNEL(mean_iou, ops::MeanIoUCUDAOpKernel<int>,
163+
ops::MeanIoUCUDAOpKernel<int64_t>,
164+
ops::MeanIoUCUDAOpKernel<int32_t>);

paddle/fluid/operators/mean_iou_op.h

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <algorithm>
17+
#include "paddle/fluid/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
using Tensor = framework::Tensor;
22+
23+
template <typename T, int D, int MajorType = Eigen::RowMajor,
24+
typename IndexType = Eigen::DenseIndex>
25+
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
26+
27+
template <typename T>
28+
class MeanIoUKernel : public framework::OpKernel<T> {
29+
public:
30+
void Compute(const framework::ExecutionContext& ctx) const override {
31+
auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
32+
.eigen_device();
33+
// get input and output tensor
34+
auto* predictions = ctx.Input<Tensor>("Predictions");
35+
auto* labels = ctx.Input<Tensor>("Labels");
36+
auto* out_mean_iou = ctx.Output<Tensor>("OutMeanIou");
37+
auto* out_wrong = ctx.Output<Tensor>("OutWrong");
38+
auto* out_correct = ctx.Output<Tensor>("OutCorrect");
39+
int num_classes = static_cast<int>(ctx.Attr<int>("num_classes"));
40+
41+
// get data ptr
42+
const T* predictions_data = predictions->data<T>();
43+
const T* labels_data = labels->data<T>();
44+
float* out_mean_iou_data =
45+
out_mean_iou->mutable_data<float>(ctx.GetPlace());
46+
int* out_wrong_data = out_wrong->mutable_data<int>(ctx.GetPlace());
47+
int* out_correct_data = out_correct->mutable_data<int>(ctx.GetPlace());
48+
49+
// get eigen tensor
50+
auto out_mean_iou_t = EigenTensor<float, 1>::From(*out_mean_iou);
51+
auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong);
52+
auto out_correct_t = EigenTensor<int, 1>::From(*out_correct);
53+
54+
// Tmp tensor
55+
Tensor denominator;
56+
Tensor valid_count;
57+
Tensor iou_sum;
58+
59+
// get data ptr of tmp tensor
60+
int* denominator_data = denominator.mutable_data<int>(
61+
{static_cast<int64_t>(num_classes)}, ctx.GetPlace());
62+
int* valid_count_data = valid_count.mutable_data<int>({1}, ctx.GetPlace());
63+
float* iou_sum_data = iou_sum.mutable_data<float>({1}, ctx.GetPlace());
64+
65+
// get eigen tensor of tmp tensor
66+
auto denominator_t = EigenTensor<int, 1>::From(denominator);
67+
auto valid_count_t = EigenTensor<int, 1>::From(valid_count);
68+
auto iou_sum_t = EigenTensor<float, 1>::From(iou_sum);
69+
70+
// init out_wrong, out_correct and out_mean_iou
71+
out_wrong_t = out_wrong_t.constant(0);
72+
out_correct_t = out_correct_t.constant(0);
73+
out_mean_iou_t = out_mean_iou_t.constant(0);
74+
75+
// collect pre wrong, correct and mean_iou
76+
auto in_mean_ious = ctx.MultiInput<Tensor>("InMeanIou");
77+
for (size_t i = 0; i < in_mean_ious.size(); ++i) {
78+
out_mean_iou_t.device(place) +=
79+
EigenTensor<float, 1>::From(*in_mean_ious[i]);
80+
}
81+
auto in_wrongs = ctx.MultiInput<Tensor>("InWrongs");
82+
for (size_t i = 0; i < in_wrongs.size(); ++i) {
83+
out_wrong_t.device(place) += EigenTensor<int, 1>::From(*in_wrongs[i]);
84+
}
85+
auto in_corrects = ctx.MultiInput<Tensor>("InCorrects");
86+
for (size_t i = 0; i < in_corrects.size(); ++i) {
87+
out_correct_t.device(place) += EigenTensor<int, 1>::From(*in_corrects[i]);
88+
}
89+
90+
// compute
91+
for (int64_t i = 0; i < predictions->numel(); ++i) {
92+
if (predictions_data[i] == labels_data[i]) {
93+
out_correct_data[predictions_data[i]] += 1;
94+
} else {
95+
out_wrong_data[labels_data[i]] += 1;
96+
out_wrong_data[predictions_data[i]] += 1;
97+
}
98+
}
99+
100+
denominator_t = out_wrong_t + out_correct_t;
101+
valid_count_t =
102+
(denominator_t > denominator_t.constant(0.0f)).cast<int>().sum();
103+
104+
for (int i = 0; i < num_classes; ++i) {
105+
if (denominator_data[i] == 0) {
106+
denominator_data[i] = 1;
107+
}
108+
}
109+
110+
iou_sum_t =
111+
(out_correct_t.cast<float>() / denominator_t.cast<float>()).sum();
112+
out_mean_iou_data[0] += (iou_sum_data[0] / valid_count_data[0]);
113+
}
114+
};
115+
116+
} // namespace operators
117+
} // namespace paddle

0 commit comments

Comments
 (0)