Skip to content

Commit b3cf133

Browse files
committed
optimize
1 parent 9494833 commit b3cf133

File tree

1 file changed

+51
-54
lines changed

1 file changed

+51
-54
lines changed
Lines changed: 51 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,89 @@
11
#include "pad_reflect_1d.cuh"
22

3-
static __global__ void pad_reflect_1d_kernel_f32(
4-
const void * __restrict__ src0,
5-
void * __restrict__ dst,
6-
const int64_t ne0,
7-
const int64_t ne00,
8-
const int64_t ne01,
9-
const uint3 ne01_packed,
10-
const int64_t ne02,
11-
const int64_t ne03,
12-
const int64_t nb00,
13-
const int64_t nb01,
14-
const int64_t nb02,
15-
const int64_t nb03,
16-
const int64_t nb0,
17-
const int64_t nb1,
18-
const int64_t nb2,
19-
const int64_t nb3,
20-
const int p0,
21-
const int p1) {
22-
3+
static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
4+
pad_reflect_1d_kernel_f32(
5+
const void * __restrict__ src0,
6+
void * __restrict__ dst,
7+
const int64_t ne0,
8+
const int64_t ne00,
9+
const uint3 ne01,
10+
const int64_t ne02,
11+
const int64_t ne03,
12+
const int64_t nb00,
13+
const int64_t nb01,
14+
const int64_t nb02,
15+
const int64_t nb03,
16+
const int64_t nb0,
17+
const int64_t nb1,
18+
const int64_t nb2,
19+
const int64_t nb3,
20+
const int p0,
21+
const int p1) {
2322
const int64_t i3 = blockIdx.z;
2423
const int64_t i2 = blockIdx.y;
2524

26-
const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01_packed);
27-
const int64_t tile1 = div_mod_packed.y; // i1
28-
const int64_t tile0 = div_mod_packed.x; // nth i0 tile
29-
const int64_t i1 = tile1;
30-
const int64_t i0 = threadIdx.x + tile0 * blockDim.x;
31-
if ( i0 >= ne0 || i1 >= ne01 || i2 >= ne02 || i3 >= ne03 ) {
25+
const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01);
26+
const int64_t tile1 = div_mod_packed.y; // i1
27+
const int64_t tile0 = div_mod_packed.x; // nth i0 tile
28+
const int64_t i1 = tile1;
29+
const int64_t i0 = threadIdx.x + tile0 * blockDim.x;
30+
31+
// ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh)
32+
if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
3233
return;
3334
}
3435

35-
const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
36-
char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
36+
const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
37+
char * dst_ptr = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
3738

38-
float value;
39+
float value;
3940
const int64_t j = i0 - p0;
4041

41-
if ( j<0 ) {// i0<p0
42+
if (j < 0) { // i0<p0
4243
// Left padding - reflect
43-
value = *(const float *)(src0_ptr - j * nb00);
44-
} else if ( j < ne00 ) { //i0 < ne0 - p1
44+
value = *(const float *) (src0_ptr - j * nb00);
45+
} else if (j < ne00) { //i0 < ne0 - p1
4546
// Middle - copy
46-
value = *(const float *)(src0_ptr + j * nb00);
47-
} else {
47+
value = *(const float *) (src0_ptr + j * nb00);
48+
} else {
4849
// Right padding - reflect
4950
const int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
50-
value = *(const float *)(src0_ptr + src_idx * nb00);
51+
value = *(const float *) (src0_ptr + src_idx * nb00);
5152
}
52-
*(float *)(dst_ptr + i0 * nb0) = value;
53+
*(float *) (dst_ptr + i0 * nb0) = value;
5354
}
5455

5556
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
56-
const ggml_tensor * src0 = dst->src[0];
57-
cudaStream_t stream = ctx.stream();
57+
const ggml_tensor * src0 = dst->src[0];
58+
cudaStream_t stream = ctx.stream();
5859

5960
GGML_ASSERT(src0->type == GGML_TYPE_F32);
6061
GGML_ASSERT(dst->type == GGML_TYPE_F32);
6162

6263
const int32_t * opts = (const int32_t *) dst->op_params;
63-
const int p0 = opts[0];
64-
const int p1 = opts[1];
64+
const int p0 = opts[0];
65+
const int p1 = opts[1];
6566

66-
const int64_t ne00 = src0->ne[0];
67-
const int64_t ne01 = src0->ne[1];
67+
const int64_t ne00 = src0->ne[0];
68+
const int64_t ne01 = src0->ne[1];
6869
const uint3 ne01_packed = init_fastdiv_values(ne01);
69-
const int64_t ne02 = src0->ne[2];
70-
const int64_t ne03 = src0->ne[3];
70+
const int64_t ne02 = src0->ne[2];
71+
const int64_t ne03 = src0->ne[3];
7172

7273
const int64_t ne0 = dst->ne[0];
7374

7475
// sanity: padded length matches
7576
GGML_ASSERT(ne0 == ne00 + p0 + p1);
7677

77-
constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x)
78-
const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0
78+
constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x)
79+
const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0
7980
// grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
8081
// grid.y covers i2: [ne02]
8182
// grid.z covers i3: [ne03]
82-
const dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02, (unsigned)ne03);
83-
const dim3 block_dims((unsigned)bx, 1, 1);
83+
const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
84+
const dim3 block_dims((unsigned) bx, 1, 1);
8485

8586
pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
86-
src0->data, dst->data,
87-
ne0, ne00, ne01, ne01_packed, ne02, ne03,
88-
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
89-
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
90-
p0, p1
91-
);
87+
src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
88+
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
9289
}

0 commit comments

Comments
 (0)