@@ -21,32 +21,33 @@ static __global__ void pad_reflect_1d_kernel_f32(
2121
2222 const int64_t i3 = blockIdx .z ;
2323 const int64_t i2 = blockIdx .y ;
24- const int64_t i1 = blockIdx .x ;
2524
26- if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
25+ const int64_t tile1 = blockIdx .x % ne01; // i1
26+ const int64_t tile0 = blockIdx .x / ne01; // nth i0 tile
27+ const int64_t i1 = tile1;
28+ const int64_t i0 = threadIdx .x + tile0 * blockDim .x ;
29+ if ( i0 >= ne0 || i1 >= ne01 || i2 >= ne02 || i3 >= ne03 ) {
2730 return ;
2831 }
2932
3033 const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
31- char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
34+ const char * dst_ptr = (const char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
3235
33- for ( int64_t i0 = threadIdx . x ; i0 < ne0; i0 += blockDim . x ) {
34- float value ;
36+ float value;
37+ const int64_t j = i0 - p0 ;
3538
36- if (i0 < p0) {
37- // Left padding - reflect
38- value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
39- } else if (i0 < ne0 - p1) {
40- // Middle - copy
41- value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
42- } else {
43- // Right padding - reflect
44- int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1 ;
45- value = *(const float *)(src0_ptr + src_idx * nb00);
46- }
47-
48- *(float *)(dst_ptr + i0 * nb0) = value;
39+ if ( j<0 ) {// i0<p0
40+ // Left padding - reflect
41+ value = *(const float *)(src0_ptr - j * nb00);
42+ } else if (j < ne00) { // i0 < ne0 - p1
43+ // Middle - copy
44+ value = *(const float *)(src0_ptr + j * nb00);
45+ } else {
46+ // Right padding - reflect
47+ const int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1 ;
48+ value = *(const float *)(src0_ptr + src_idx * nb00);
4949 }
50+ *(float *)(dst_ptr + i0 * nb0) = value;
5051}
5152
5253void ggml_cuda_op_pad_reflect_1d (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -67,10 +68,16 @@ void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor *
6768
6869 const int64_t ne0 = dst->ne [0 ];
6970
71+ // sanity: padded output length (ne0) must equal source length (ne00) plus left (p0) and right (p1) padding
7072 GGML_ASSERT (ne0 == ne00 + p0 + p1);
7173
72- const dim3 block_dims (CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1 , 1 );
73- const dim3 grid_dims (ne01, ne02, ne03);
74+ constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x)
75+ const int64_t tiles0 = (ne0 + bx - 1 ) / bx; // number of tiles along i0
76+ // grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
77+ // grid.y covers i2: [ne02]
78+ // grid.z covers i3: [ne03]
79+ const dim3 grid_dims ((unsigned )(ne01 * tiles0), (unsigned )ne02, (unsigned )ne03);
80+ const dim3 block_dims ((unsigned )bx, 1 , 1 );
7481
7582 pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0 , stream>>> (
7683 src0->data , dst->data ,
0 commit comments