Skip to content

Commit d67fcb8

Browse files
committed
check fusion=false for ncols_dst!=1
1 parent e212c85 commit d67fcb8

File tree

2 files changed: 8 additions and 9 deletions.

2 files changed: 8 additions and 9 deletions.

ggml/src/ggml-cuda/mmvf.cu

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -163,13 +163,7 @@ static __global__ void mul_mat_vec_f(
163163
} else {
164164
#ifdef FP16_AVAILABLE
165165
half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
166-
half2 sumh2_gate[ncols_dst];
167-
if constexpr (has_fusion) {
168-
#pragma unroll
169-
for (int j = 0; j < ncols_dst; ++j) {
170-
sumh2_gate[j] = make_half2(0.0f, 0.0f);
171-
}
172-
}
166+
half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};
173167

174168
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
175169
const half2 tmpx = x2[col2];
@@ -359,8 +353,8 @@ static void mul_mat_vec_f_switch_fusion(
359353
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
360354
const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
361355

356+
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
362357
if constexpr (ncols_dst == 1) {
363-
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
364358
if (has_fusion) {
365359
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
366360
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
@@ -370,6 +364,8 @@ static void mul_mat_vec_f_switch_fusion(
370364
}
371365
}
372366

367+
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
368+
373369
mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
374370
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
375371
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -161,6 +161,7 @@ static __global__ void mul_mat_vec_q(
161161
const int blocks_per_row_x = ncols_x / qk;
162162
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
163163

164+
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
164165
const uint32_t channel_dst = blockIdx.y;
165166
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
166167
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
@@ -324,8 +325,8 @@ static void mul_mat_vec_q_switch_fusion(
324325
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
325326
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
326327

328+
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
327329
if constexpr (c_ncols_dst == 1) {
328-
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
329330
if (has_fusion) {
330331
mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
331332
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
@@ -335,6 +336,8 @@ static void mul_mat_vec_q_switch_fusion(
335336
}
336337
}
337338

339+
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
340+
338341
mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
339342
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
340343
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,

0 commit comments

Comments
 (0)