@@ -6,6 +6,7 @@ static __global__ void pad_reflect_1d_kernel_f32(
66 const int64_t ne0,
77 const int64_t ne00,
88 const int64_t ne01,
9+ const uint3 ne01_packed,
910 const int64_t ne02,
1011 const int64_t ne03,
1112 const int64_t nb00,
@@ -22,24 +23,25 @@ static __global__ void pad_reflect_1d_kernel_f32(
2223 const int64_t i3 = blockIdx .z ;
2324 const int64_t i2 = blockIdx .y ;
2425
25- const int64_t tile1 = blockIdx .x % ne01; // i1
26- const int64_t tile0 = blockIdx .x / ne01; // nth i0 tile
26+ const uint2 div_mod_packed = fast_div_modulo (blockIdx .x , ne01_packed);
27+ const int64_t tile1 = div_mod_packed.y ; // i1
28+ const int64_t tile0 = div_mod_packed.x ; // nth i0 tile
2729 const int64_t i1 = tile1;
2830 const int64_t i0 = threadIdx .x + tile0 * blockDim .x ;
2931 if ( i0 >= ne0 || i1 >= ne01 || i2 >= ne02 || i3 >= ne03 ) {
3032 return ;
3133 }
3234
3335 const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
34- const char * dst_ptr = (const char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
36+ char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
3537
3638 float value;
3739 const int64_t j = i0 - p0;
3840
3941 if ( j<0 ) {// i0<p0
4042 // Left padding - reflect
4143 value = *(const float *)(src0_ptr - j * nb00);
42- } else if (j < ne00) { // i0 < ne0 - p1
44+ } else if ( j < ne00 ) { // i0 < ne0 - p1
4345 // Middle - copy
4446 value = *(const float *)(src0_ptr + j * nb00);
4547 } else {
@@ -63,6 +65,7 @@ void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor *
6365
6466 const int64_t ne00 = src0->ne [0 ];
6567 const int64_t ne01 = src0->ne [1 ];
68+ const uint3 ne01_packed = init_fastdiv_values (ne01);
6669 const int64_t ne02 = src0->ne [2 ];
6770 const int64_t ne03 = src0->ne [3 ];
6871
@@ -81,7 +84,7 @@ void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor *
8184
8285 pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0 , stream>>> (
8386 src0->data , dst->data ,
84- ne0, ne00, ne01, ne02, ne03,
87+ ne0, ne00, ne01, ne01_packed, ne02, ne03,
8588 src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ],
8689 dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ],
8790 p0, p1
0 commit comments