@@ -163,13 +163,7 @@ static __global__ void mul_mat_vec_f(
163163 } else {
164164#ifdef FP16_AVAILABLE
165165 half2 sumh2[ncols_dst] = {{0 .0f , 0 .0f }};
166- half2 sumh2_gate[ncols_dst];
167- if constexpr (has_fusion) {
168- #pragma unroll
169- for (int j = 0 ; j < ncols_dst; ++j) {
170- sumh2_gate[j] = make_half2 (0 .0f , 0 .0f );
171- }
172- }
166+ half2 sumh2_gate[ncols_dst] = {{0 .0f , 0 .0f }};
173167
174168 for (int col2 = tid; col2 < ncols2; col2 += block_size) {
175169 const half2 tmpx = x2[col2];
@@ -359,8 +353,8 @@ static void mul_mat_vec_f_switch_fusion(
359353 const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
360354 const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
361355
356+ const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr ;
362357 if constexpr (ncols_dst == 1 ) {
363- const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr ;
364358 if (has_fusion) {
365359 mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true ><<<block_nums, block_dims, nbytes_shared, stream>>>
366360 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
@@ -370,6 +364,8 @@ static void mul_mat_vec_f_switch_fusion(
370364 }
371365 }
372366
367+ GGML_ASSERT (!has_fusion && " fusion only supported for ncols_dst=1" );
368+
373369 mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
374370 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
375371 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
0 commit comments