
Commit 05eb910

Merge branch 'ggml-org:master' into master
2 parents: 1d863bf + 4258e0c

Showing 25 changed files with 921 additions and 574 deletions.

ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 7 deletions
@@ -944,13 +944,6 @@ struct ggml_cuda_graph {
     bool disable_due_to_failed_graph_capture = false;
     int number_consecutive_updates = 0;
     std::vector<ggml_graph_node_properties> ggml_graph_properties;
-    bool use_cpy_indirection = false;
-    std::vector<char *> cpy_dest_ptrs;
-    char ** dest_ptrs_d;
-    int dest_ptrs_size = 0;
-    // Index to allow each cpy kernel to be aware of it's position within the graph
-    // relative to other cpy nodes.
-    int graph_cpynode_index = -1;
 #endif
 };
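These were the fields behind the copy-op pointer-indirection scheme used during CUDA graph capture: the host gathered every copy destination into cpy_dest_ptrs, mirrored that table to the device through dest_ptrs_d, and each cpy kernel picked its destination via graph_cpynode_index. A minimal sketch of the retired pattern, reconstructed from the code this commit deletes (names follow the removed fields; details approximate, not current ggml API):

    #include <vector>
    #include <cuda_runtime.h>

    // Sketch of the retired indirection state (illustration only).
    struct cpy_indirection_state {
        std::vector<char *> cpy_dest_ptrs; // host-side table, refreshed each token
        char ** dest_ptrs_d = nullptr;     // device mirror of the table
    };

    static void refresh_dest_ptrs(cpy_indirection_state & s, cudaStream_t stream) {
        // Mirror the host table to the device so cpy kernels baked into a
        // captured CUDA graph keep writing to the right place after the
        // destinations move between tokens.
        cudaMemcpyAsync(s.dest_ptrs_d, s.cpy_dest_ptrs.data(),
                        s.cpy_dest_ptrs.size()*sizeof(char *),
                        cudaMemcpyHostToDevice, stream);
    }

    // Each cpy kernel then resolved its destination through its position in
    // the graph:  char * dst = dest_ptrs[graph_cpynode_index];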

ggml/src/ggml-cuda/cpy.cu

Lines changed: 55 additions & 163 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/cpy.cuh

Lines changed: 1 addition & 5 deletions
@@ -2,10 +2,6 @@
 
 #define CUDA_CPY_BLOCK_SIZE 64
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
-
-void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
-
-void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);

ggml/src/ggml-cuda/fattn-vec.cuh

Lines changed: 2 additions & 7 deletions
@@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
     const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
     const int nwarps = nthreads / WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
-    constexpr bool need_f16_K = false;
-    constexpr bool need_f16_V = false;
+    const bool need_f16_K = type_K == GGML_TYPE_F16;
+    const bool need_f16_V = type_V == GGML_TYPE_F16;
     constexpr size_t nbytes_shared = 0;
     launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
@@ -526,11 +526,6 @@ template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    GGML_ASSERT(K->type == type_K);
-    GGML_ASSERT(V->type == type_V);
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
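Previously both flags could be hard-coded to false because only F16 K/V ever reached these instantiations. With F32 tensors now routed to the F16 cases as well (see the FATTN_VEC_CASE change in fattn.cu below), the flags must state what the kernel was compiled for. A sketch of the semantics as I read this diff, assuming launch_fattn converts K/V to F16 whenever the corresponding flag is set and the tensor is not already F16:

    // need_f16_X == true  -> kernel was compiled for F16 X; an incoming F32
    //                        tensor is expected to be converted before launch.
    // need_f16_X == false -> kernel was compiled for a quantized X and
    //                        dequantizes it itself; no conversion wanted.
    const bool need_f16_K = type_K == GGML_TYPE_F16;
    const bool need_f16_V = type_V == GGML_TYPE_F16;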

ggml/src/ggml-cuda/fattn.cu

Lines changed: 12 additions & 7 deletions
@@ -116,11 +116,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     }
 }
 
-#define FATTN_VEC_CASE(D, type_K, type_V)                                \
-    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
-        ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);  \
-        return;                                                          \
-    }                                                                    \
+#define FATTN_VEC_CASE(D, type_K, type_V)                                                                        \
+    {                                                                                                            \
+        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
+        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
+        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) {                                                     \
+            ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);                                      \
+            return;                                                                                              \
+        }                                                                                                        \
+    }                                                                                                            \
 
 #define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
     FATTN_VEC_CASE( 64, type_K, type_V) \
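For concreteness, here is roughly what the reworked macro expands to for the all-F16 case, with the line-continuation backslashes dropped; the new outer braces scope the two locals so back-to-back FATTN_VEC_CASE expansions don't redeclare them:

    // FATTN_VEC_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16) now also accepts F32
    // tensors for K and V (expansion shown for illustration):
    {
        const bool type_K_okay = K->type == GGML_TYPE_F16 || (K->type == GGML_TYPE_F32 && GGML_TYPE_F16 == GGML_TYPE_F16);
        const bool type_V_okay = V->type == GGML_TYPE_F16 || (V->type == GGML_TYPE_F32 && GGML_TYPE_F16 == GGML_TYPE_F16);
        if (Q->ne[0] == 64 && type_K_okay && type_V_okay) {
            ggml_cuda_flash_attn_ext_vec_case<64, GGML_TYPE_F16, GGML_TYPE_F16>(ctx, dst);
            return;
        }
    }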
@@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
     switch (K->type) {
+        case GGML_TYPE_F32:
         case GGML_TYPE_F16:
             break;
         case GGML_TYPE_Q4_1:
@@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     // If Turing tensor cores available, use them:
     if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
         if (can_use_vector_kernel) {
-            if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                     return BEST_FATTN_KERNEL_VEC;
                 }
@@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 
     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
-        if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
             if (Q->ne[1] == 1) {
                 if (!gqa_opt_applies) {
                     return BEST_FATTN_KERNEL_VEC;
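Note the relaxed gate: since the K->type switch above already rejects everything that is neither F32, F16, nor a supported quantized type, the new condition appears to reduce to "K and V are plain float" (assuming V->type is filtered the same way elsewhere in this function):

    // My reading of the relaxed vector-kernel gate (sketch only):
    const bool kv_plain_float =
        (K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16) &&
        (V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16);
    // ... !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)
    // selects the same tensors at this point.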

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 31 deletions
@@ -2633,11 +2633,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 }
 
 #ifdef USE_CUDA_GRAPH
-static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
     bool use_cuda_graph) {
 
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
 
     const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
     const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
@@ -2688,33 +2687,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }
 
-        if (node->op == GGML_OP_CPY) {
-
-            // Store the pointers which are updated for each token, such that these can be sent
-            // to the device and accessed using indirection from CUDA graph
-            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
-
-            // store a pointer to each copy op CUDA kernel to identify it later
-            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-            if (!ptr) {
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
-#endif
-            }
-        }
-
         if (!use_cuda_graph) {
             break;
         }
     }
 
-    if (use_cuda_graph) {
-        cuda_ctx->cuda_graph->use_cpy_indirection = true;
-        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
-        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
-    }
-
     return use_cuda_graph;
 }

@@ -2733,7 +2710,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
 
 static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
     if (node->data != graph_node_properties->node_address &&
-        node->op != GGML_OP_CPY &&
         node->op != GGML_OP_VIEW) {
         return false;
     }
@@ -2754,7 +2730,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         if (node->src[i] &&
             node->src[i]->data != graph_node_properties->src_address[i] &&
-            node->op != GGML_OP_CPY &&
             node->op != GGML_OP_VIEW
         ) {
             return false;
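With the indirection gone, CPY nodes lose their exemption from address matching: a copy whose destination pointer moved between tokens now registers as a property mismatch and triggers a CUDA graph update, instead of being patched through the device-side pointer table. Condensed, the surviving check reads:

    // Only VIEW nodes may change their data address without invalidating the
    // captured graph; a moved CPY destination now fails this test.
    if (node->data != graph_node_properties->node_address &&
        node->op != GGML_OP_VIEW) {
        return false;
    }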
@@ -3120,7 +3095,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
     if (use_cuda_graph) {
         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
 
-        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
+        use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
 
         // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
         if (use_cuda_graph && cuda_graph_update_required) {
@@ -3147,10 +3122,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
-    if (!use_cuda_graph) {
-        cuda_ctx->cuda_graph->use_cpy_indirection = false;
-    }
-
 #else
     bool use_cuda_graph = false;
     bool cuda_graph_update_required = false;

ggml/src/ggml-cuda/mmf.cu

Lines changed: 40 additions & 6 deletions
@@ -1,5 +1,7 @@
 #include "ggml.h"
 #include "mmf.cuh"
+#include "mmid.cuh"
+
 
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
     GGML_ASSERT( src1->type == GGML_TYPE_F32);
@@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
     const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
     const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
 
+    mmf_ids_data ids_info{};
+    mmf_ids_data * ids_info_ptr = nullptr;
+    ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
+    ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
+    ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
+
     // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
     const int64_t ncols_dst = ids ? ne2 : ne1;
     const int64_t nchannels_dst = ids ? ne1 : ne2;
@@ -54,30 +62,57 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
         nchannels_y = ids->ne[0];
     }
 
+    if (ids && ncols_dst > 16) {
+        const int64_t n_expert_used = ids->ne[0];
+        const int64_t n_experts = ne02;
+        const int64_t n_tokens = ne12;
+        const int64_t ne_get_rows = n_tokens * n_expert_used;
+
+        ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
+        ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
+        expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
+
+        const int si1 = static_cast<int>(ids_s1);
+        const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
+
+        GGML_ASSERT(sis1 > 0);
+
+        ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
+            static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
+        CUDA_CHECK(cudaGetLastError());
+
+        ids_info.ids_src_compact = ids_src_compact_dev.get();
+        ids_info.ids_dst_compact = ids_dst_compact_dev.get();
+        ids_info.expert_bounds_dev = expert_bounds_dev.get();
+        ids_info.n_experts = static_cast<int>(n_experts);
+        ids_info.sis1 = sis1;
+        ids_info_ptr = &ids_info;
+    }
+
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             constexpr int vals_per_T = 1;
             mul_mat_f_switch_cols_per_block(
                 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         case GGML_TYPE_F16: {
             const half2 * src0_d = (const half2 *) src0->data;
             constexpr int vals_per_T = 2;
             mul_mat_f_switch_cols_per_block(
                 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         case GGML_TYPE_BF16: {
            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
            constexpr int vals_per_T = 2;
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
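The new branch handles MUL_MAT_ID with more than 16 destination columns by compacting the (token, expert-slot) pairs per expert before the GEMM: ggml_cuda_launch_mm_ids_helper fills ids_src_compact/ids_dst_compact with per-expert contiguous index runs and expert_bounds with each run's start offset. A rough CPU restatement of the idea (illustration only; the real device helper also folds in the src1 row stride sis1 and ne11, so its exact index encoding may differ):

    #include <cstdint>
    #include <vector>

    // CPU sketch of the per-expert compaction (not the CUDA helper itself).
    static void mm_ids_compact_ref(const int32_t * ids, int n_experts, int n_tokens,
                                   int n_expert_used, int si1 /* ids row stride */,
                                   std::vector<int32_t> & ids_src_compact,
                                   std::vector<int32_t> & ids_dst_compact,
                                   std::vector<int32_t> & expert_bounds) {
        expert_bounds.assign(n_experts + 1, 0);
        for (int e = 0; e < n_experts; ++e) {
            expert_bounds[e] = (int32_t) ids_src_compact.size();
            for (int t = 0; t < n_tokens; ++t) {
                for (int u = 0; u < n_expert_used; ++u) {
                    if (ids[t*si1 + u] == e) {
                        ids_src_compact.push_back(t);                   // src1 row feeding expert e
                        ids_dst_compact.push_back(t*n_expert_used + u); // output slot for the pair
                    }
                }
            }
        }
        expert_bounds[n_experts] = (int32_t) ids_src_compact.size();
    }

Each expert can then be processed as one dense slice [expert_bounds[e], expert_bounds[e+1]) rather than scattering per-row index lookups through the matmul kernel.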
@@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
     }
 
     if (mul_mat_id) {
-        if (type == GGML_TYPE_F32 && src1_ncols > 32) {
+        if (src0_ne[1] <= 1024 && src1_ncols > 512) {
             return false;
-        }
-        if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
+        } else if(src0_ne[1] > 1024 && src1_ncols > 128) {
             return false;
         }
     } else {
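The mmf cutoff is no longer keyed on the source type but on the matrix height. Assuming src0_ne[1] is the per-expert output row count, the new rejection rule condenses to:

    // Condensed restatement of the new MUL_MAT_ID gate (my reading):
    // e.g. a 1024-row expert tolerates up to 512 src1 columns, a 4096-row
    // expert only up to 128, before mmf is skipped in favor of another path.
    const bool reject_mmf = (src0_ne[1] <= 1024 && src1_ncols > 512) ||
                            (src0_ne[1] >  1024 && src1_ncols > 128);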
