49 changes: 47 additions & 2 deletions ggml/src/ggml-cuda/cpy.cu
@@ -7,6 +7,34 @@

typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

template<typename T>
static __global__ void cpy_contiguous(const T * cx, T * cdst_direct, const int ne_elements,
                                      T ** cdst_indirect, int graph_cpynode_index) {
    const int64_t tid    = blockDim.x * blockIdx.x + threadIdx.x;
    const int64_t stride = blockDim.x * gridDim.x;

    T * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index] : cdst_direct;

    const int elements_per_thread = 4;
Collaborator:
This is declared both when launching the kernel and here; perhaps make it a constant like CUDA_CPY_BLOCK_SIZE.
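
A minimal sketch of that suggestion (the constant name CUDA_CPY_ELEMENTS_PER_THREAD and the helper below are hypothetical, not part of this PR, and the kernel body is reduced to scalar copies for brevity): define the value once, e.g. next to CUDA_CPY_BLOCK_SIZE, and let both the kernel and the launcher reference it.

    #include <cstdint>

    // Hypothetical shared constant (name not in this PR); defining it once, e.g. next to
    // CUDA_CPY_BLOCK_SIZE, keeps the kernel and the launcher from drifting apart.
    #define CUDA_CPY_ELEMENTS_PER_THREAD 4

    template<typename T>
    static __global__ void cpy_contiguous_sketch(const T * cx, T * cdst, const int ne_elements) {
        const int64_t tid    = blockDim.x * blockIdx.x + threadIdx.x;
        const int64_t stride = blockDim.x * gridDim.x;
        // same grid-stride walk as the kernel above, with the shared constant instead of a local 4
        for (int64_t base_idx = tid * CUDA_CPY_ELEMENTS_PER_THREAD; base_idx < ne_elements;
             base_idx += stride * CUDA_CPY_ELEMENTS_PER_THREAD) {
            for (int j = 0; j < CUDA_CPY_ELEMENTS_PER_THREAD && base_idx + j < ne_elements; ++j) {
                cdst[base_idx + j] = cx[base_idx + j];
            }
        }
    }

    // The launcher derives its grid size from the same constant.
    static int cpy_contiguous_num_blocks(const int ne_elements, const int block_size) {
        const int threads_needed = (ne_elements + CUDA_CPY_ELEMENTS_PER_THREAD - 1) / CUDA_CPY_ELEMENTS_PER_THREAD;
        return (threads_needed + block_size - 1) / block_size;
    }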

    for (int64_t base_idx = tid * elements_per_thread; base_idx < ne_elements; base_idx += stride * elements_per_thread) {
        const int64_t remaining = ne_elements - base_idx;

        if (remaining >= elements_per_thread) {
            if (base_idx % 4 == 0) {
Collaborator:
From what I understand, base_idx is always a multiple of elements_per_thread, which is 4, so this check is not necessary?
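
If that holds (base_idx starts at tid * elements_per_thread and advances by stride * elements_per_thread, so it stays a multiple of 4), the branch without the check could look like the sketch below; float-only for brevity, and the float4 access additionally assumes the base pointers are 16-byte aligned, which is true for cudaMalloc allocations.

    #include <cstdint>

    // Sketch only: copy one group of up to 4 floats starting at a 4-aligned base_idx.
    static __device__ void copy_group_sketch(const float * cx, float * cdst,
                                             const int64_t base_idx, const int64_t ne_elements) {
        const int64_t remaining = ne_elements - base_idx;
        if (remaining >= 4) {
            // base_idx is a multiple of 4 by construction, so no % 4 check is needed here
            *((float4 *)(cdst + base_idx)) = *((const float4 *)(cx + base_idx));
        } else {
            for (int j = 0; j < remaining; ++j) {
                cdst[base_idx + j] = cx[base_idx + j];
            }
        }
    }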

                *((float4*)(cdst + base_idx)) = *((const float4*)(cx + base_idx));
            } else {
                for (int j = 0; j < elements_per_thread && base_idx + j < ne_elements; ++j) {
                    cdst[base_idx + j] = cx[base_idx + j];
                }
            }
        } else {
            for (int j = 0; j < remaining; ++j) {
Collaborator:
Since remaining is < 4 here, we could do an unroll like below, but I doubt it will have any effect on performance:

        #pragma unroll
        for (int j = 0; j < 4; ++j) {
            size_t i = base_idx + (size_t) j;
            if (i < ne_elements) cdst[i] = cx[i];
        }

                cdst[base_idx + j] = cx[base_idx + j];
            }
        }
    }
}

template <cpy_kernel_t cpy_1>
static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -138,6 +166,23 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
#endif
}

template<typename T>
static void ggml_cpy_contiguous_cuda(
    const T * cx, T * cdst, const int ne_elements,
    cudaStream_t stream, T ** cdst_indirect, int & graph_cpynode_index) {

    if (ne_elements <= 0) {
        return;
    }

    const int elements_per_thread = 4;
    const int threads_needed = (ne_elements + elements_per_thread - 1) / elements_per_thread;
    const int num_blocks = max(1, min(65535, (threads_needed + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE));
Collaborator:
I think this won't work if ne_elements is larger than a certain amount (I think 16726016). We could add an assert here, or check whether the num_blocks limit can be higher than 65535.
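
A sketch of the assert option, to drop into ggml_cpy_contiguous_cuda where elements_per_thread and CUDA_CPY_BLOCK_SIZE are in scope (GGML_ASSERT is the existing ggml assertion macro; the bound is what one clamped grid pass can touch). Note that the grid-stride loop in cpy_contiguous advances by stride * elements_per_thread, so elements beyond a single pass are still copied in later iterations; the assert only makes the single-pass assumption explicit:

    // Sketch only: with num_blocks clamped to 65535, one grid pass covers at most
    // 65535 * CUDA_CPY_BLOCK_SIZE * elements_per_thread elements.
    GGML_ASSERT((int64_t) ne_elements <= (int64_t) 65535 * CUDA_CPY_BLOCK_SIZE * elements_per_thread);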


    cpy_contiguous<T><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne_elements, cdst_indirect, graph_cpynode_index++);
}

template<typename src_t, typename dst_t>
static void ggml_cpy_flt_cuda(
    const char * cx, char * cdst, const int ne,
@@ -330,7 +375,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
        {
            if (src0->type == GGML_TYPE_F32) {
-                ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+                ggml_cpy_contiguous_cuda<float>((const float*)src0_ddc, (float*)src1_ddc, ne, main_stream, (float**)dest_ptrs_d, graph_cpynode_index);
            } else {
                CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
            }
@@ -407,7 +452,7 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
        // Prioritize CUDA graph compatibility over direct memory copy optimization.
        // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
        if (src0->type == GGML_TYPE_F32) {
-            return (void*) cpy_flt<cpy_1_flt<float, float>>;
+            return (void*) cpy_contiguous<float>;
        } else {
            return nullptr;
        }