Skip to content

Commit 9494833

Browse files
committed
use fast_div to improve performance
1 parent 1e29faf commit 9494833

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,14 @@ static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fa
636636
return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
637637
}
638638

639+
// Calculate both division and modulo at once using a single fastdiv call.
// Returns a uint2 holding <n/divisor, n%divisor> in <x, y>.
640+
static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) {
641+
// expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
642+
const uint32_t div_val = fastdiv(n, fastdiv_values);
643+
// remainder is derived from the quotient: n - (n/divisor) * divisor, with the divisor read from .z
const uint32_t mod_val = n - div_val * fastdiv_values.z;
644+
return make_uint2(div_val, mod_val);
645+
}
646+
639647
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
640648

641649
static __device__ __forceinline__ float get_alibi_slope(

ggml/src/ggml-cuda/pad_reflect_1d.cu

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ static __global__ void pad_reflect_1d_kernel_f32(
66
const int64_t ne0,
77
const int64_t ne00,
88
const int64_t ne01,
9+
const uint3 ne01_packed,
910
const int64_t ne02,
1011
const int64_t ne03,
1112
const int64_t nb00,
@@ -22,24 +23,25 @@ static __global__ void pad_reflect_1d_kernel_f32(
2223
const int64_t i3 = blockIdx.z;
2324
const int64_t i2 = blockIdx.y;
2425

25-
const int64_t tile1 = blockIdx.x % ne01; // i1
26-
const int64_t tile0 = blockIdx.x / ne01; // nth i0 tile
26+
const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01_packed);
27+
const int64_t tile1 = div_mod_packed.y; // i1
28+
const int64_t tile0 = div_mod_packed.x; // nth i0 tile
2729
const int64_t i1 = tile1;
2830
const int64_t i0 = threadIdx.x + tile0 * blockDim.x;
2931
if ( i0 >= ne0 || i1 >= ne01 || i2 >= ne02 || i3 >= ne03 ) {
3032
return;
3133
}
3234

3335
const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
34-
const char * dst_ptr = (const char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
36+
char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
3537

3638
float value;
3739
const int64_t j = i0 - p0;
3840

3941
if ( j<0 ) {// i0<p0
4042
// Left padding - reflect
4143
value = *(const float *)(src0_ptr - j * nb00);
42-
} else if (j < ne00) { //i0 < ne0 - p1
44+
} else if ( j < ne00 ) { //i0 < ne0 - p1
4345
// Middle - copy
4446
value = *(const float *)(src0_ptr + j * nb00);
4547
} else {
@@ -63,6 +65,7 @@ void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor *
6365

6466
const int64_t ne00 = src0->ne[0];
6567
const int64_t ne01 = src0->ne[1];
68+
const uint3 ne01_packed = init_fastdiv_values(ne01);
6669
const int64_t ne02 = src0->ne[2];
6770
const int64_t ne03 = src0->ne[3];
6871

@@ -81,7 +84,7 @@ void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor *
8184

8285
pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
8386
src0->data, dst->data,
84-
ne0, ne00, ne01, ne02, ne03,
87+
ne0, ne00, ne01, ne01_packed, ne02, ne03,
8588
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
8689
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
8790
p0, p1

0 commit comments

Comments
 (0)