-
Notifications
You must be signed in to change notification settings - Fork 13.5k
CUDA: add set #14980
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
CUDA: add set #14980
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| #include "ggml-cuda/common.cuh" | ||
| #include "set.cuh" | ||
|
|
||
| static __global__ void set_f32_cuda_copy(const float * __restrict__ src0, | ||
| float * __restrict__ dst, | ||
| const size_t ne0, | ||
| const size_t ne1, | ||
| const size_t ne2, | ||
| const size_t ne3, | ||
| const size_t nb0, | ||
| const size_t nb1, | ||
| const size_t nb2, | ||
| const size_t nb3) { | ||
| const size_t total = ne0 * ne1 * ne2 * ne3; | ||
| const size_t gid = blockIdx.x * blockDim.x + threadIdx.x; | ||
| if (gid >= total) { | ||
| return; | ||
| } | ||
|
|
||
| size_t tmp = gid; | ||
|
|
||
| const size_t i0 = tmp % ne0; | ||
| tmp /= ne0; | ||
| const size_t i1 = tmp % ne1; | ||
| tmp /= ne1; | ||
| const size_t i2 = tmp % ne2; | ||
| tmp /= ne2; | ||
| const size_t i3 = tmp; | ||
|
|
||
| const size_t pos = (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); | ||
|
|
||
| *((float *) ((char *) dst + pos)) = *((const float *) ((const char *) src0 + pos)); | ||
| } | ||
|
|
||
| static __global__ void set_f32_cuda(const float * __restrict__ src1, | ||
| float * __restrict__ dst, | ||
| const size_t ne10, | ||
| const size_t ne11, | ||
| const size_t ne12, | ||
| const size_t ne13, | ||
| const size_t nb10, | ||
| const size_t nb11, | ||
| const size_t nb12, | ||
| const size_t nb13, | ||
| const size_t nb0, | ||
| const size_t nb1, | ||
| const size_t nb2, | ||
| const size_t nb3, | ||
| const size_t offset | ||
|
|
||
| ) { | ||
| const size_t total = ne10 * ne11 * ne12 * ne13; | ||
| const size_t gid = blockIdx.x * blockDim.x + threadIdx.x; | ||
| if (gid >= total) { | ||
| return; | ||
| } | ||
|
|
||
| size_t tmp = gid; | ||
|
|
||
| const size_t i0 = tmp % ne10; | ||
| tmp /= ne10; | ||
| const size_t i1 = tmp % ne11; | ||
| tmp /= ne11; | ||
| const size_t i2 = tmp % ne12; | ||
| tmp /= ne12; | ||
| const size_t i3 = tmp; | ||
|
|
||
| size_t dst_offset = offset + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3; | ||
| size_t src1_offset = i0 * nb10 + i1 * nb11 + i2 * nb12 + i3 * nb13; | ||
|
|
||
| *((float *) ((char *) dst + dst_offset)) = *((const float *) ((const char *) src1 + src1_offset)); | ||
| } | ||
|
|
||
| void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
| // nb0 is implicitly element_size because src0 and dst are contiguous | ||
| const int32_t nb1 = dst->op_params[0]; | ||
| const int32_t nb2 = dst->op_params[1]; | ||
| const int32_t nb3 = dst->op_params[2]; | ||
| const int32_t offset = dst->op_params[3]; | ||
| const bool inplace = dst->op_params[4]; | ||
|
|
||
| const ggml_tensor * src0 = dst->src[0]; | ||
| const ggml_tensor * src1 = dst->src[1]; | ||
|
|
||
| GGML_ASSERT(ggml_are_same_shape(src0, dst)); | ||
|
|
||
| // TODO: support more dtypes. | ||
| GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||
| GGML_ASSERT(src1->type == GGML_TYPE_F32); | ||
| GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||
|
|
||
| GGML_TENSOR_BINARY_OP_LOCALS01; | ||
| const int nb0 = ggml_element_size(dst); | ||
|
|
||
| const float * src0_d = (const float *) src0->data; | ||
| const float * src1_d = (const float *) src1->data; | ||
| float * dst_d = (float *) dst->data; | ||
|
|
||
| cudaStream_t stream = ctx.stream(); | ||
|
|
||
| if (!inplace) { | ||
| // copy whole src0 -> dst. | ||
| const size_t total = ne00 * ne01 * ne02 * ne03; | ||
|
|
||
| const int num_blocks = (total + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE; | ||
|
|
||
| set_f32_cuda_copy<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>( | ||
| src0_d, dst_d, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03); | ||
am17an marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| // set: src1 -> dst | ||
| // set_f32_cuda | ||
|
|
||
| const size_t total = ne10 * ne11 * ne12 * ne13; | ||
| const size_t num_blocks = (total + CUDA_SET_BLOCK_SIZE - 1) / CUDA_SET_BLOCK_SIZE; | ||
|
|
||
| set_f32_cuda<<<num_blocks, CUDA_SET_BLOCK_SIZE, 0, stream>>>( | ||
| src1_d, dst_d, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, offset); | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| #pragma once | ||
|
|
||
| #include "common.cuh" | ||
|
|
||
| #define CUDA_SET_BLOCK_SIZE 256 | ||
|
|
||
| void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.