Skip to content

Commit 3fbff1e

Browse files
author
sweetsky0901
committed
for code review 5
1 parent 350cc61 commit 3fbff1e

File tree

3 files changed

+7
-0
lines changed

3 files changed

+7
-0
lines changed

paddle/operators/math/maxouting.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class MaxOutGradFunctor<platform::CPUPlace, T> {
8989
if (input_data[input_idx] == output_data[output_idx]) {
9090
input_grad_data[input_idx] += output_grad_data[output_idx];
9191
continue_match = false;
92+
break;
9293
}
9394
}
9495
}

paddle/operators/math/maxouting.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ __global__ void KernelMaxoutGrad(
6565
if (input_data[data_idx + g * feat_len] == output_data[i]) {
6666
max_index = data_idx + g * feat_len;
6767
continue_match = false;
68+
break;
6869
}
6970
}
7071
if (max_index != -1) {

paddle/operators/maxout_op.cu.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
namespace ops = paddle::operators;
1818
REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::GPUPlace,
1919
float>);
20+
REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::GPUPlace,
21+
double>);
2022
REGISTER_OP_GPU_KERNEL(maxout_grad,
2123
ops::MaxOutGradKernel<paddle::platform::GPUPlace,
2224
float>);
25+
REGISTER_OP_GPU_KERNEL(maxout_grad,
26+
ops::MaxOutGradKernel<paddle::platform::GPUPlace,
27+
double>);

0 commit comments

Comments
 (0)