-
Notifications
You must be signed in to change notification settings - Fork 13.4k
CUDA: add set rows for f32 and f16 #14551
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+147
−0
Merged
Changes from 3 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
#include "set-rows.cuh" | ||
|
||
typedef void (*set_rows_kernel_t)(const char * src, char * dst); | ||
|
||
template<typename src_t, typename dst_t> | ||
__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {} | ||
|
||
template<> | ||
__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) { | ||
*dst_h = __float2half(*src_f); | ||
} | ||
|
||
template<> | ||
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) { | ||
*dst_f = *src_f; | ||
} | ||
|
||
template<typename src_t, typename dst_t> | ||
static __global__ void k_set_rows( | ||
const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst, | ||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, | ||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, | ||
const size_t nb01, const size_t nb02, const size_t nb03, | ||
const size_t nb10, const size_t nb11, const size_t nb12, | ||
const size_t nb1, const size_t nb2, const size_t nb3) { | ||
|
||
const int64_t i = blockDim.x * blockIdx.x + threadIdx.x; | ||
am17an marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
const int64_t ne_total = ne00 * ne01 * ne02 * ne03; | ||
|
||
if (i >= ne_total) { | ||
return; | ||
} | ||
|
||
const int64_t i03 = i / (ne00 * ne01 * ne02); | ||
const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); | ||
const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; | ||
const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; | ||
|
||
const int64_t i12 = i03 % ne12; | ||
const int64_t i11 = i02 % ne11; | ||
const int64_t i10 = i01; | ||
|
||
const int64_t dst_row = *(src1 + i10*nb10 + i11*nb11 + i12*nb12); | ||
|
||
const src_t * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03; | ||
dst_t * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3; | ||
|
||
const src_t* src_elem = src0_row + i00; | ||
dst_t* dst_elem = dst_row_ptr + i00; | ||
set_rows_1(src_elem, dst_elem); | ||
} | ||
|
||
template<typename src_t, typename dst_t> | ||
static void set_rows_cuda( | ||
const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d, | ||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, | ||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, | ||
const size_t nb01, const size_t nb02, const size_t nb03, | ||
const size_t nb10, const size_t nb11, const size_t nb12, | ||
const size_t nb1, const size_t nb2, const size_t nb3, | ||
cudaStream_t stream) { | ||
|
||
const int64_t ne_total = ne00 * ne01 * ne02 * ne03; | ||
const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE; | ||
const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE); | ||
const dim3 grid_size(num_blocks); | ||
|
||
|
||
const int64_t s01 = nb01/sizeof(src_t); | ||
const int64_t s02 = nb02/sizeof(src_t); | ||
const int64_t s03 = nb03/sizeof(src_t); | ||
const int64_t s10 = nb10/sizeof(int64_t); | ||
const int64_t s11 = nb11/sizeof(int64_t); | ||
const int64_t s12 = nb12/sizeof(int64_t); | ||
const int64_t s1 = nb1/sizeof(dst_t); | ||
const int64_t s2 = nb2/sizeof(dst_t); | ||
const int64_t s3 = nb3/sizeof(dst_t); | ||
|
||
if (ne_total > 0) { | ||
k_set_rows<<<grid_size, block_size, 0, stream>>>( | ||
src0_d, src1_d, dst_d, | ||
ne00, ne01, ne02, ne03, | ||
ne10, ne11, ne12, ne13, | ||
s01, s02, s03, | ||
s10, s11, s12, | ||
s1, s2, s3); | ||
} | ||
} | ||
|
||
|
||
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
const ggml_tensor * src0 = dst->src[0]; | ||
const ggml_tensor * src1 = dst->src[1]; | ||
|
||
GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||
GGML_ASSERT(src1->type == GGML_TYPE_I64); | ||
|
||
GGML_TENSOR_BINARY_OP_LOCALS | ||
|
||
const float * src0_d = (const float *)src0->data; | ||
const int64_t * src1_d = (const int64_t *)src1->data; | ||
|
||
cudaStream_t stream = ctx.stream(); | ||
|
||
|
||
|
||
if (dst->type == GGML_TYPE_F32) { | ||
set_rows_cuda( | ||
src0_d, src1_d, (float*)dst->data, | ||
ne00, ne01, ne02, ne03, | ||
ne10, ne11, ne12, ne13, | ||
nb01, nb02, nb03, | ||
nb10, nb11, nb12, | ||
nb1, nb2, nb3, | ||
stream | ||
); | ||
} else if (dst->type == GGML_TYPE_F16) { | ||
set_rows_cuda( | ||
src0_d, src1_d, (half*)dst->data, | ||
ne00, ne01, ne02, ne03, | ||
ne10, ne11, ne12, ne13, | ||
nb01, nb02, nb03, | ||
nb10, nb11, nb12, | ||
nb1, nb2, nb3, | ||
stream | ||
); | ||
} else { | ||
GGML_ABORT("unsupported type"); | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#pragma once | ||
|
||
#include "common.cuh" | ||
|
||
#define CUDA_SET_ROWS_BLOCK_SIZE 256 | ||
|
||
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.