@@ -1040,72 +1040,75 @@ void CommonGradBroadcastCUDA(
1040
1040
// fallback
1041
1041
// to old fast path.
1042
1042
// 2. if both x and y need broadcast, then do it one by one.
1043
+ bool fast_broadcast = false ;
1043
1044
if (x_broadcast_pos.empty () && !y_broadcast_pos.empty ()) {
1044
1045
can_split_y = SplitDims (y_broadcast_pos, max_dim);
1045
1046
if (can_split_y) {
1046
1047
// only y need to do broadcast on h
1047
1048
if (y_broadcast_pos[0 ] == 0 ) {
1048
1049
FastBroadCastHeightCUDAF (y_broadcast_pos, true );
1049
- } else {
1050
- LOG (ERROR) << " Error, broadcast should not into w broadcast" ;
1050
+ fast_broadcast = true ;
1051
1051
}
1052
- return ;
1053
1052
} else if (y_broadcast_pos.size () == 1 ||
1054
1053
CheckContiguousDims (y_broadcast_pos)) { // for only one dim and
1055
1054
// contiguous broadcast.
1056
1055
// If cannot split, which means input has 3 parts
1057
1056
FastBroadCastAllCUDAF (y_broadcast_pos, max_dim, true );
1058
- return ;
1057
+ fast_broadcast = true ;
1059
1058
}
1060
1059
} else if (y_broadcast_pos.empty () && !x_broadcast_pos.empty ()) {
1061
1060
// only x need broadcast
1062
1061
can_split_x = SplitDims (x_broadcast_pos, max_dim);
1063
1062
if (can_split_x) {
1064
1063
if (x_broadcast_pos[0 ] == 0 ) {
1065
1064
FastBroadCastHeightCUDAF (x_broadcast_pos, false );
1066
- } else {
1067
- // x need to do broadcast on w
1068
- LOG (ERROR) << " Error, broadcast should not into w broadcast" ;
1065
+ fast_broadcast = true ;
1069
1066
}
1070
- return ;
1071
1067
} else if (x_broadcast_pos.size () == 1 ||
1072
1068
CheckContiguousDims (x_broadcast_pos)) {
1073
1069
FastBroadCastAllCUDAF (x_broadcast_pos, max_dim, false );
1074
- return ;
1070
+ fast_broadcast = true ;
1075
1071
}
1076
1072
} else if (!x_broadcast_pos.empty () && !y_broadcast_pos.empty ()) {
1077
1073
// do x and y broadcast each.
1078
1074
can_split_y = SplitDims (y_broadcast_pos, max_dim);
1075
+ bool fast_broadcast_x = false ;
1076
+ bool fast_broadcast_y = false ;
1079
1077
if (can_split_y) {
1080
1078
// begin at start.
1081
1079
if (y_broadcast_pos[0 ] == 0 ) {
1082
1080
FastCommonCUDAF (y_broadcast_pos, true );
1083
- } else {
1084
- // finish at end
1085
- LOG (ERROR) << " Error, broadcast should not into w broadcast" ;
1081
+ fast_broadcast_y = true ;
1086
1082
}
1087
1083
} else if (y_broadcast_pos.size () == 1 ) {
1088
1084
FastBroadCastOneCUDAF (y_broadcast_pos, max_dim, false );
1089
1085
can_split_y = true ;
1086
+ fast_broadcast_y = true ;
1090
1087
}
1091
1088
can_split_x = SplitDims (x_broadcast_pos, max_dim);
1092
1089
if (can_split_x) {
1093
1090
if (x_broadcast_pos[0 ] == 0 ) {
1094
1091
FastCommonCUDAF (x_broadcast_pos, false );
1095
- } else {
1096
- LOG (ERROR) << " Error, broadcast should not into w broadcast" ;
1092
+ fast_broadcast_x = true ;
1097
1093
}
1098
1094
} else if (x_broadcast_pos.size () == 1 ) {
1099
1095
FastBroadCastOneCUDAF (x_broadcast_pos, max_dim, true );
1100
1096
can_split_x = true ;
1097
+ fast_broadcast_x = true ;
1101
1098
}
1102
1099
VLOG (3 ) << " CommonBroadcast can_split_y:" << can_split_y
1103
1100
<< " can_split_x:" << can_split_x;
1104
1101
// if both x and y into fast path then return
1105
- if (can_split_y && can_split_x) return ;
1102
+ if (fast_broadcast_x && fast_broadcast_y) {
1103
+ fast_broadcast = true ;
1104
+ }
1105
+ if (can_split_y && can_split_x && fast_broadcast) return ;
1106
1106
}
1107
1107
1108
1108
// Should remove memory copy, use reg instead.
1109
+ if (fast_broadcast) {
1110
+ return ;
1111
+ }
1109
1112
int x_blocks = 0 ;
1110
1113
int x_threads = 0 ;
1111
1114
ComputeBroadcastKernelSize (x_dims_array, out_dims_array, &x_blocks,
@@ -1136,7 +1139,7 @@ void CommonGradBroadcastCUDA(
1136
1139
1 , std::multiplies<int >());
1137
1140
int x_block_size = std::min (ELEMWISE_MAX_BLOCK_DIM, x_threads);
1138
1141
int y_block_size = std::min (ELEMWISE_MAX_BLOCK_DIM, y_threads);
1139
- if (dx && !can_split_x ) {
1142
+ if (dx) {
1140
1143
auto x_strides_order_tmp = memory::Alloc (ctx, bytes);
1141
1144
int *x_strides_order_gpu =
1142
1145
reinterpret_cast <int *>(x_strides_order_tmp->ptr ());
@@ -1153,7 +1156,7 @@ void CommonGradBroadcastCUDA(
1153
1156
x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data,
1154
1157
dout_data, dx_data, out_size, max_dim, x_threads, dx_op);
1155
1158
}
1156
- if (dy && !can_split_y ) {
1159
+ if (dy) {
1157
1160
auto y_strides_order_tmp = memory::Alloc (ctx, bytes);
1158
1161
int *y_strides_order_gpu =
1159
1162
reinterpret_cast <int *>(y_strides_order_tmp->ptr ());
0 commit comments