Skip to content

Commit 07dc5a1

Browse files
authored
Add generate_mask_labels_op to support Mask-RCNN and refine some code. (#15371)
* Add generate_mask_labels_op to support Mask-RCNN. * Refine sigmoid_cross_entropy to support normalize mode. * Fix generate_proposal_labels. * Use DeviceTemporaryAllocator in roi_pool and roi_align. * Remove shape check in data_feeder.
1 parent 9f5108a commit 07dc5a1

23 files changed

+1933
-204
lines changed

paddle/fluid/API.spec

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -197,7 +197,7 @@ paddle.fluid.layers.clip ArgSpec(args=['x', 'min', 'max', 'name'], varargs=None,
197197
paddle.fluid.layers.clip_by_norm ArgSpec(args=['x', 'max_norm', 'name'], varargs=None, keywords=None, defaults=(None,))
198198
paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
199199
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
200-
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'ignore_index', 'name'], varargs=None, keywords=None, defaults=(-100, None))
200+
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'ignore_index', 'name', 'normalize'], varargs=None, keywords=None, defaults=(-100, None, False))
201201
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
202202
paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,))
203203
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
@@ -318,6 +318,7 @@ paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'asp
318318
paddle.fluid.layers.roi_perspective_transform ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,))
319319
paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True))
320320
paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
321+
paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', 'is_crowd', 'gt_segms', 'rois', 'labels_int32', 'num_classes', 'resolution'], varargs=None, keywords=None, defaults=None)
321322
paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
322323
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
323324
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))

paddle/fluid/operators/affine_channel_op.cu

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ __global__ void AffineChannelScaleBiasGradientCUDAKernel(
8383
T* dbias) {
8484
const int outer_size = C;
8585
const int inner_size = N * HxW;
86-
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
86+
typedef cub::BlockReduce<double, BlockDim> BlockReduce;
8787
__shared__ typename BlockReduce::TempStorage ds_storage;
8888
__shared__ typename BlockReduce::TempStorage db_storage;
8989

@@ -97,13 +97,16 @@ __global__ void AffineChannelScaleBiasGradientCUDAKernel(
9797
ds_sum += dy[index] * x[index];
9898
db_sum += dy[index];
9999
}
100-
ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
101-
db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
100+
__syncthreads();
101+
auto ds_out =
102+
BlockReduce(ds_storage).Reduce(static_cast<double>(ds_sum), cub::Sum());
103+
auto db_out =
104+
BlockReduce(db_storage).Reduce(static_cast<double>(db_sum), cub::Sum());
105+
__syncthreads();
102106
if (threadIdx.x == 0) {
103-
dscale[i] = ds_sum;
104-
dbias[i] = db_sum;
107+
dscale[i] = ds_out;
108+
dbias[i] = db_out;
105109
}
106-
__syncthreads();
107110
}
108111
}
109112

paddle/fluid/operators/detection/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,3 +45,7 @@ detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op
4545
foreach(src ${LOCAL_DETECTION_LIBS})
4646
set(OP_LIBRARY ${src} ${OP_LIBRARY} CACHE INTERNAL "op libs")
4747
endforeach()
48+
49+
cc_library(mask_util SRCS mask_util.cc DEPS memory)
50+
cc_test(mask_util_test SRCS mask_util_test.cc DEPS memory mask_util)
51+
detection_library(generate_mask_labels_op SRCS generate_mask_labels_op.cc DEPS mask_util)

paddle/fluid/operators/detection/bbox_util.h

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,17 @@
11
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
23
Licensed under the Apache License, Version 2.0 (the "License");
34
you may not use this file except in compliance with the License.
45
You may obtain a copy of the License at
6+
57
http://www.apache.org/licenses/LICENSE-2.0
8+
69
Unless required by applicable law or agreed to in writing, software
710
distributed under the License is distributed on an "AS IS" BASIS,
811
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
912
See the License for the specific language governing permissions and
1013
limitations under the License. */
14+
1115
#pragma once
1216
#include <algorithm>
1317
#include "paddle/fluid/framework/eigen.h"
@@ -88,7 +92,9 @@ void BboxOverlaps(const framework::Tensor& r_boxes,
8892
inter_w = std::max(x_max - x_min + 1, zero);
8993
inter_h = std::max(y_max - y_min + 1, zero);
9094
inter_area = inter_w * inter_h;
91-
overlaps_et(i, j) = inter_area / (r_box_area + c_box_area - inter_area);
95+
overlaps_et(i, j) =
96+
(inter_area == 0.) ? 0 : inter_area /
97+
(r_box_area + c_box_area - inter_area);
9298
}
9399
}
94100
}

0 commit comments

Comments
 (0)