-
Notifications
You must be signed in to change notification settings - Fork 13.5k
CUDA: Optimize PAD_REFLECT_1D #15957
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
JohannesGaessler
merged 11 commits into
ggml-org:master
from
bugparty:PAD_REFLECT_1D_expriment
Sep 18, 2025
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
1e29faf
CUDA: Optimize PAD_REFLECT_1D
bugparty 9494833
use fast_div to improve performance
bugparty 8583552
Apply suggestion from @JohannesGaessler
bugparty a5ef1d0
Apply suggestion from @JohannesGaessler
bugparty b3cf133
optimize
bugparty d73ba84
use a concise expression to further speedup the cuda kernel
bugparty e280cb8
add comment for rel_i0
bugparty 188ce93
Merge branch 'ggml-org:master' into PAD_REFLECT_1D_expriment
bugparty 4286ea7
Merge branch 'ggml-org:master' into PAD_REFLECT_1D_expriment
bugparty dd6789b
Merge branch 'ggml-org:master' into PAD_REFLECT_1D_expriment
bugparty aa12620
Merge branch 'ggml-org:master' into PAD_REFLECT_1D_expriment
bugparty File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,82 +1,89 @@ | ||
| #include "pad_reflect_1d.cuh" | ||
|
|
||
| static __global__ void pad_reflect_1d_kernel_f32( | ||
| const void * __restrict__ src0, | ||
| void * __restrict__ dst, | ||
| const int64_t ne0, | ||
| const int64_t ne00, | ||
| const int64_t ne01, | ||
| const int64_t ne02, | ||
| const int64_t ne03, | ||
| const int64_t nb00, | ||
| const int64_t nb01, | ||
| const int64_t nb02, | ||
| const int64_t nb03, | ||
| const int64_t nb0, | ||
| const int64_t nb1, | ||
| const int64_t nb2, | ||
| const int64_t nb3, | ||
| const int p0, | ||
| const int p1) { | ||
|
|
||
| static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void | ||
| pad_reflect_1d_kernel_f32( | ||
| const void * __restrict__ src0, | ||
| void * __restrict__ dst, | ||
| const int64_t ne0, | ||
| const int64_t ne00, | ||
| const uint3 ne01, | ||
| const int64_t ne02, | ||
| const int64_t ne03, | ||
| const int64_t nb00, | ||
| const int64_t nb01, | ||
| const int64_t nb02, | ||
| const int64_t nb03, | ||
| const int64_t nb0, | ||
| const int64_t nb1, | ||
| const int64_t nb2, | ||
| const int64_t nb3, | ||
| const int p0, | ||
| const int p1) { | ||
| const int64_t i3 = blockIdx.z; | ||
| const int64_t i2 = blockIdx.y; | ||
| const int64_t i1 = blockIdx.x; | ||
|
|
||
| if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) { | ||
| const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01); | ||
| const int64_t tile1 = div_mod_packed.y; // i1 | ||
| const int64_t tile0 = div_mod_packed.x; // nth i0 tile | ||
| const int64_t i1 = tile1; | ||
| const int64_t i0 = threadIdx.x + tile0 * blockDim.x; | ||
|
|
||
| // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh) | ||
| if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) { | ||
| return; | ||
| } | ||
|
|
||
| const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01; | ||
| char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1; | ||
|
|
||
| for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { | ||
| float value; | ||
| const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01; | ||
| char * dst_ptr = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1; | ||
|
|
||
| if (i0 < p0) { | ||
| // Left padding - reflect | ||
| value = *(const float *)(src0_ptr + (p0 - i0) * nb00); | ||
| } else if (i0 < ne0 - p1) { | ||
| // Middle - copy | ||
| value = *(const float *)(src0_ptr + (i0 - p0) * nb00); | ||
| } else { | ||
| // Right padding - reflect | ||
| int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1; | ||
| value = *(const float *)(src0_ptr + src_idx * nb00); | ||
| } | ||
| const int64_t rel_i0 = i0 - p0; // relative i0 in src0 | ||
| int64_t src_idx; | ||
bugparty marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| *(float *)(dst_ptr + i0 * nb0) = value; | ||
| if (rel_i0 < 0) { | ||
| // Left padding - reflect | ||
| src_idx = -rel_i0; | ||
| } else if (rel_i0 < ne00) { | ||
| // Middle - copy | ||
| src_idx = rel_i0; | ||
| } else { | ||
| // Right padding - reflect | ||
| src_idx = 2 * ne00 - 2 - rel_i0; | ||
| } | ||
| const float value = *(const float *) (src0_ptr + src_idx * nb00); | ||
| *(float *) (dst_ptr + i0 * nb0) = value; | ||
| } | ||
|
|
||
| void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
| const ggml_tensor * src0 = dst->src[0]; | ||
| cudaStream_t stream = ctx.stream(); | ||
| const ggml_tensor * src0 = dst->src[0]; | ||
| cudaStream_t stream = ctx.stream(); | ||
|
|
||
| GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||
| GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||
|
|
||
| const int32_t * opts = (const int32_t *) dst->op_params; | ||
| const int p0 = opts[0]; | ||
| const int p1 = opts[1]; | ||
| const int p0 = opts[0]; | ||
| const int p1 = opts[1]; | ||
|
|
||
| const int64_t ne00 = src0->ne[0]; | ||
| const int64_t ne01 = src0->ne[1]; | ||
| const int64_t ne02 = src0->ne[2]; | ||
| const int64_t ne03 = src0->ne[3]; | ||
| const int64_t ne00 = src0->ne[0]; | ||
| const int64_t ne01 = src0->ne[1]; | ||
| const uint3 ne01_packed = init_fastdiv_values(ne01); | ||
| const int64_t ne02 = src0->ne[2]; | ||
| const int64_t ne03 = src0->ne[3]; | ||
|
|
||
| const int64_t ne0 = dst->ne[0]; | ||
|
|
||
| // sanity: padded length matches | ||
| GGML_ASSERT(ne0 == ne00 + p0 + p1); | ||
|
|
||
| const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1); | ||
| const dim3 grid_dims(ne01, ne02, ne03); | ||
| constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x) | ||
| const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0 | ||
| // grid.x covers i1 and all tiles of i0: [ne01 * tiles0] | ||
| // grid.y covers i2: [ne02] | ||
| // grid.z covers i3: [ne03] | ||
| const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03); | ||
| const dim3 block_dims((unsigned) bx, 1, 1); | ||
|
|
||
| pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>( | ||
| src0->data, dst->data, | ||
| ne0, ne00, ne01, ne02, ne03, | ||
| src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], | ||
| dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], | ||
| p0, p1 | ||
| ); | ||
| src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], | ||
| dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1); | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.