
Commit c0b7241

Fix the broadcast bug in the elementwise ops (cherry-picked from develop)

PR types: Bug fixes
PR changes: Ops
Describe: Fix the broadcast bug in elementwise ops. When an elementwise op has no fast gradient-calculation path, fall back to the default path to compute the gradient instead of returning early.
1 parent 5993dde commit c0b7241
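For context, the case the fix targets is the one where both inputs must be broadcast against each other. Below is a minimal NumPy sketch of the forward shapes, using the same shapes as the new unit test added further down (variable names are illustrative only, not from the commit):

import numpy as np

# Same shapes as TestElementwiseAddOp_broadcast_7 below: x broadcasts along
# its first two axes and y along its last two, so both inputs need broadcast.
x = np.random.rand(1, 1, 20, 5).astype("float32")
y = np.random.rand(20, 5, 1, 1).astype("float32")

out = x + y
print(out.shape)  # (20, 5, 20, 5)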

File tree
2 files changed: +27, -17 lines


paddle/fluid/operators/elementwise/elementwise_op_function.h

Lines changed: 20 additions & 17 deletions
@@ -1040,72 +1040,75 @@ void CommonGradBroadcastCUDA(
   // fallback
   // to old fast path.
   // 2. if both x and y need broadcast, then do it one by one.
+  bool fast_broadcast = false;
   if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
     can_split_y = SplitDims(y_broadcast_pos, max_dim);
     if (can_split_y) {
       // only y need to do broadcast on h
       if (y_broadcast_pos[0] == 0) {
         FastBroadCastHeightCUDAF(y_broadcast_pos, true);
-      } else {
-        LOG(ERROR) << "Error, broadcast should not into w broadcast";
+        fast_broadcast = true;
       }
-      return;
     } else if (y_broadcast_pos.size() == 1 ||
                CheckContiguousDims(y_broadcast_pos)) {  // for only one dim and
                                                         // contiguous broadcast.
       // If cannot split, which means input has 3 parts
       FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true);
-      return;
+      fast_broadcast = true;
     }
   } else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) {
     // only x need broadcast
     can_split_x = SplitDims(x_broadcast_pos, max_dim);
     if (can_split_x) {
       if (x_broadcast_pos[0] == 0) {
         FastBroadCastHeightCUDAF(x_broadcast_pos, false);
-      } else {
-        // x need to do broadcast on w
-        LOG(ERROR) << "Error, broadcast should not into w broadcast";
+        fast_broadcast = true;
       }
-      return;
     } else if (x_broadcast_pos.size() == 1 ||
                CheckContiguousDims(x_broadcast_pos)) {
       FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false);
-      return;
+      fast_broadcast = true;
     }
   } else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
     // do x and y broadcast each.
     can_split_y = SplitDims(y_broadcast_pos, max_dim);
+    bool fast_broadcast_x = false;
+    bool fast_broadcast_y = false;
     if (can_split_y) {
       // begin at start.
       if (y_broadcast_pos[0] == 0) {
         FastCommonCUDAF(y_broadcast_pos, true);
-      } else {
-        // finish at end
-        LOG(ERROR) << "Error, broadcast should not into w broadcast";
+        fast_broadcast_y = true;
       }
     } else if (y_broadcast_pos.size() == 1) {
       FastBroadCastOneCUDAF(y_broadcast_pos, max_dim, false);
       can_split_y = true;
+      fast_broadcast_y = true;
     }
     can_split_x = SplitDims(x_broadcast_pos, max_dim);
     if (can_split_x) {
       if (x_broadcast_pos[0] == 0) {
         FastCommonCUDAF(x_broadcast_pos, false);
-      } else {
-        LOG(ERROR) << "Error, broadcast should not into w broadcast";
+        fast_broadcast_x = true;
       }
     } else if (x_broadcast_pos.size() == 1) {
       FastBroadCastOneCUDAF(x_broadcast_pos, max_dim, true);
       can_split_x = true;
+      fast_broadcast_x = true;
     }
     VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y
             << " can_split_x:" << can_split_x;
     // if both x and y into fast path then return
-    if (can_split_y && can_split_x) return;
+    if (fast_broadcast_x && fast_broadcast_y) {
+      fast_broadcast = true;
+    }
+    if (can_split_y && can_split_x && fast_broadcast) return;
   }
 
   // Should remove memory copy, use reg instead.
+  if (fast_broadcast) {
+    return;
+  }
   int x_blocks = 0;
   int x_threads = 0;
   ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks,
@@ -1136,7 +1139,7 @@ void CommonGradBroadcastCUDA(
                                  1, std::multiplies<int>());
   int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads);
   int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads);
-  if (dx && !can_split_x) {
+  if (dx) {
     auto x_strides_order_tmp = memory::Alloc(ctx, bytes);
     int *x_strides_order_gpu =
         reinterpret_cast<int *>(x_strides_order_tmp->ptr());
@@ -1153,7 +1156,7 @@ void CommonGradBroadcastCUDA(
         x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data,
         dout_data, dx_data, out_size, max_dim, x_threads, dx_op);
   }
-  if (dy && !can_split_y) {
+  if (dy) {
     auto y_strides_order_tmp = memory::Alloc(ctx, bytes);
     int *y_strides_order_gpu =
         reinterpret_cast<int *>(y_strides_order_tmp->ptr());

python/paddle/fluid/tests/unittests/test_elementwise_add_op.py

Lines changed: 7 additions & 0 deletions
@@ -263,6 +263,13 @@ def init_input_output(self):
         self.out = self.x + self.y
 
 
+class TestElementwiseAddOp_broadcast_7(TestElementwiseAddOp):
+    def init_input_output(self):
+        self.x = np.random.rand(1, 1, 20, 5).astype(self.dtype)
+        self.y = np.random.rand(20, 5, 1, 1).astype(self.dtype)
+        self.out = self.x + self.y
+
+
 class TestFP16ElementwiseAddOp_broadcast_6(TestFP16ElementwiseAddOp):
     def init_input_output(self):
         self.x = np.random.rand(2, 12, 3, 5).astype(self.dtype)
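For reference (not part of the commit), the gradient shapes the general broadcast path must produce for the new test's inputs can be sanity-checked with plain NumPy. For out = x + y, each input's gradient is the upstream gradient summed over that input's broadcast axes; this rule is standard for elementwise add and independent of the Paddle code:

import numpy as np

x = np.random.rand(1, 1, 20, 5).astype("float32")
y = np.random.rand(20, 5, 1, 1).astype("float32")
dout = np.ones((20, 5, 20, 5), dtype="float32")  # upstream gradient of out = x + y

# dz/dx is dout summed over the axes where x was broadcast, and likewise
# for y; each result keeps the corresponding input's original shape.
dx = dout.sum(axis=(0, 1), keepdims=True)
dy = dout.sum(axis=(2, 3), keepdims=True)

assert dx.shape == x.shape
assert dy.shape == y.shape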
