@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
- #include " RowConvOp .h"
16
- #include " hl_base .h"
15
+ #include " paddle/cuda/include/hl_base .h"
16
+ #include " paddle/function/RowConvOp .h"
17
17
18
18
namespace paddle {
19
19
@@ -94,7 +94,7 @@ __global__ void KeRowConv2(real* y,
94
94
}
95
95
96
96
template <>
97
- void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out,
97
+ void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out, // NOLINT
98
98
const GpuMatrix& in,
99
99
const GpuMatrix& filter,
100
100
const GpuIVector& seq) {
@@ -144,6 +144,10 @@ __global__ void KeRowConvBwWeight(real* dw,
144
144
}
145
145
__syncthreads ();
146
146
147
+ // NOTE(zcd): temporary solution
148
+ unsigned mask = 0u ;
149
+ CREATE_SHFL_MASK (mask, true );
150
+
147
151
for (int i = 0 ; i < numSeq; ++i) {
148
152
const int start = starts[i];
149
153
const int end = starts[i + 1 ];
@@ -170,11 +174,10 @@ __global__ void KeRowConvBwWeight(real* dw,
170
174
real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx + context - 1 - t];
171
175
__syncthreads ();
172
176
// warp size and blockDim.x is 32.
173
- val += __shfl_down (val, 16 );
174
- val += __shfl_down (val, 8 );
175
- val += __shfl_down (val, 4 );
176
- val += __shfl_down (val, 2 );
177
- val += __shfl_down (val, 1 );
177
+
178
+ for (int offset = 16 ; offset > 0 ; offset /= 2 )
179
+ val += __shfl_down_sync (mask, val, offset);
180
+
178
181
__syncthreads ();
179
182
if (tidx == 0 ) {
180
183
sh_dw[t][tidy] += val;
@@ -205,6 +208,10 @@ __global__ void KeRowConvBwWeight2(real* dw,
205
208
__shared__ real sh_x[BLOCK_H][BLOCK_W];
206
209
__shared__ real sh_dy[BLOCK_H][BLOCK_W];
207
210
211
+ // NOTE(zcd): temporary solution
212
+ unsigned mask = 0u ;
213
+ CREATE_SHFL_MASK (mask, true );
214
+
208
215
for (int i = 0 ; i < numSeq; ++i) {
209
216
const int start = starts[i];
210
217
const int end = starts[i + 1 ];
@@ -230,11 +237,9 @@ __global__ void KeRowConvBwWeight2(real* dw,
230
237
real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx];
231
238
__syncthreads ();
232
239
// warp size and blockDim.x is 32.
233
- val += __shfl_down (val, 16 );
234
- val += __shfl_down (val, 8 );
235
- val += __shfl_down (val, 4 );
236
- val += __shfl_down (val, 2 );
237
- val += __shfl_down (val, 1 );
240
+ for (int offset = 16 ; offset > 0 ; offset /= 2 )
241
+ val += __shfl_down_sync (mask, val, offset);
242
+
238
243
__syncthreads ();
239
244
240
245
if (tidx == 0 && (gidx + tidy) < width) {
@@ -323,8 +328,8 @@ template <>
323
328
void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
324
329
const GpuMatrix& in,
325
330
const GpuMatrix& filter,
326
- GpuMatrix& inG,
327
- GpuMatrix& filterG,
331
+ GpuMatrix& inG, // NOLINT
332
+ GpuMatrix& filterG, // NOLINT
328
333
const GpuIVector& seq) {
329
334
const size_t numSeq = seq.getSize () - 1 ;
330
335
const size_t contextLength = filter.getHeight ();
0 commit comments