Skip to content

Commit c9d2046

Browse files
committed
roi_align for gpu
1 parent 2f5a801 commit c9d2046

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/roi_align_op.h"
16+
#include "paddle/fluid/platform/cuda_primitives.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
using Tensor = framework::Tensor;
22+
using LoDTensor = framework::LoDTensor;
23+
24+
// Thread-count constant (512). Not referenced anywhere else in this file as
// shown; presumably the per-block thread count for kernel launches -- confirm
// at the use sites.
static constexpr int kNumCUDAThreads{512};
25+
// Block-count constant (4096). Not referenced anywhere else in this file as
// shown; presumably an upper bound on the number of blocks per kernel launch
// -- confirm at the use sites.
static constexpr int kNumMaximumNumBlocks = 4096;
// Original (misspelled) name, kept as an alias so any code that already refers
// to it keeps compiling. Prefer kNumMaximumNumBlocks in new code.
static constexpr int kNumMaxinumNumBlocks = kNumMaximumNumBlocks;
26+
27+
} // namespace operators
28+
} // namespace paddle
29+
30+
namespace ops = paddle::operators;
31+
// Registers the CUDA kernels for the `roi_align` op: the float and double
// instantiations of GPUROIAlignOpKernel on CUDADeviceContext.
// NOTE(review): GPUROIAlignOpKernel is not defined in this file (the
// paddle::operators namespace above only holds aliases and constants);
// presumably it is expected from roi_align_op.h -- confirm, otherwise this
// translation unit will not compile.
REGISTER_OP_CUDA_KERNEL(
31+
roi_align,
32+
ops::GPUROIAlignOpKernel<paddle::platform::CUDADeviceContext, float>,
33+
ops::GPUROIAlignOpKernel<paddle::platform::CUDADeviceContext, double>);
35+
// Registers the CUDA kernels for the `roi_align_grad` op: the float and double
// instantiations of GPUROIAlignGradOpKernel on CUDADeviceContext.
// NOTE(review): GPUROIAlignGradOpKernel is not defined in this file;
// presumably it is expected from roi_align_op.h -- confirm, otherwise this
// translation unit will not compile.
REGISTER_OP_CUDA_KERNEL(
36+
roi_align_grad,
37+
ops::GPUROIAlignGradOpKernel<paddle::platform::CUDADeviceContext, float>,
38+
ops::GPUROIAlignGradOpKernel<paddle::platform::CUDADeviceContext, double>);

0 commit comments

Comments
 (0)