Skip to content

Commit 4f1275d

Browse files
committed
Revert "cuda : remove nrows_x in mul_mat_q_process_tile (ggml-org#13325)"
This reverts commit 1f73301.
1 parent: 5ee6f19; commit: 4f1275d

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2523,7 +2523,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
25232523
static __device__ __forceinline__ void mul_mat_q_process_tile(
25242524
const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
25252525
const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2526-
const int stride_row_x, const int ncols_y, const int stride_col_dst,
2526+
const int nrows_x, const int stride_row_x, const int ncols_y, const int stride_col_dst,
25272527
const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
25282528

25292529
constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -2690,7 +2690,7 @@ static __global__ void mul_mat_q(
26902690

26912691
constexpr bool fixup = false;
26922692
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2693-
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
2693+
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
26942694
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
26952695
return;
26962696
}
@@ -2768,7 +2768,7 @@ static __global__ void mul_mat_q(
27682768

27692769
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
27702770
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2771-
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
2771+
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
27722772
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
27732773

27742774
kbc += blocks_per_ne00;
@@ -2835,7 +2835,7 @@ static __global__ void mul_mat_q(
28352835

28362836
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
28372837
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
2838-
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
2838+
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
28392839
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
28402840
}
28412841

0 commit comments

Comments
 (0)