Commit 9114078

Authored by ikawrakow (Iwan Kawrakow)
Fix quantized k-cache without FA (#105)

* Added Johannes' changes, still getting NaNs with quantized k-cache. Also getting NaNs on Johannes' mainline branch.
* This fixes it.

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent b61cf7d commit 9114078

File tree (3 files changed, +12 −13 lines):

* ggml/src/ggml-cuda.cu
* ggml/src/ggml-cuda/mmq.cu
* ggml/src/ggml-cuda/quantize.cu


ggml/src/ggml-cuda.cu

Lines changed: 9 additions & 5 deletions

@@ -1170,8 +1170,8 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
 
     GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
-    char * src_ptr = (char *) src->data;
-    char * dst_ptr = (char *) dst;
+    const char * src_ptr = (const char *) src->data;
+    char * dst_ptr = (char *) dst;
 
     const int64_t ne0 = src->ne[0];
     const int64_t nb0 = src->nb[0];
@@ -1182,7 +1182,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     const int64_t ts = ggml_type_size(type);
     const int64_t rs = ggml_row_size(type, ne0);
     const int64_t bs = ggml_blck_size(type);
-    int64_t i1_diff = i1_high - i1_low;
+    const int64_t i1_diff = i1_high - i1_low;
 
     const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
     if (nb0 == ts && nb1 == rs) {
@@ -1532,10 +1532,14 @@ static void ggml_cuda_op_mul_mat(
         if (src0_is_contiguous) {
             dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
         } else {
-            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0));
+            // If src0 is not contiguous it will be copied to a temporary buffer, it may then be necessary to clear padding.
+            const size_t nbytes_data = ggml_nbytes(src0);
+            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
+            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
+            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
         }
 
-        // If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared:
+        // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
         if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
             const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
             const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
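
The change in the non-contiguous src0 branch is the padding fix: the temporary buffer is over-allocated by one row of padding and zero-filled, so quantized matmul kernels that read slightly past the last row (the reason the padding exists) never see uninitialized bytes. Below is a minimal stand-alone sketch of the same size arithmetic and clearing; it assumes Q8_0-style blocks (32 elements, 34 bytes per block), takes the backend's MATRIX_ROW_PADDING to be 512, and uses a hypothetical row_size() in place of ggml_row_size:

// Sketch only (not the actual ggml code): allocate a dense temporary copy of a
// quantized tensor with one extra padding row and zero-fill everything, so that
// data + padding is fully initialized before any kernel reads it.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

constexpr int64_t MATRIX_ROW_PADDING = 512; // assumed value of the backend constant
constexpr int64_t Q8_0_BLCK = 32;           // elements per Q8_0 block
constexpr int64_t Q8_0_TS   = 34;           // bytes per Q8_0 block (32 int8 quants + fp16 scale)

// Hypothetical stand-in for ggml_row_size(type, n): bytes needed for n Q8_0 elements.
static int64_t row_size(int64_t n) { return Q8_0_TS * (n / Q8_0_BLCK); }

int main() {
    const int64_t ne00  = 4064; // row length, deliberately not a multiple of MATRIX_ROW_PADDING
    const int64_t nrows = 8;    // made-up number of rows

    const size_t nbytes_data    = (size_t) (row_size(ne00) * nrows);
    const size_t nbytes_padding = (size_t) row_size(MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    char * buf = nullptr;
    cudaMalloc(&buf, nbytes_data + nbytes_padding);                 // padded allocation, as in the patch
    cudaMemsetAsync(buf, 0, nbytes_data + nbytes_padding, stream);  // clear data and padding
    cudaStreamSynchronize(stream);

    printf("data: %zu bytes, padding: %zu bytes\n", nbytes_data, nbytes_padding);

    cudaFree(buf);
    cudaStreamDestroy(stream);
    return 0;
}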

ggml/src/ggml-cuda/mmq.cu

Lines changed: 1 addition & 3 deletions

@@ -8,16 +8,14 @@ void ggml_cuda_op_mul_mat_q(
 
     const int64_t ne00 = src0->ne[0];
 
-    const int64_t nb01 = src0->nb[1];
-
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
     GGML_ASSERT(ne10 % QK8_1 == 0);
 
     const int64_t ne0 = dst->ne[0];
 
     const int64_t row_diff = row_high - row_low;
-    const int64_t stride00 = nb01 / ggml_type_size(src0->type);
+    const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
 
     int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
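
With src0 possibly living in the dense temporary buffer allocated above, its rows are exactly ne00 elements (ne00 / blck_size blocks) apart, so the kernel stride is now computed from the row length instead of from src0->nb[1], which describes the original, possibly non-contiguous layout. A small illustration with made-up numbers, assuming Q8_0 block geometry; nb01 here is a hypothetical byte stride of a strided source view:

// Sketch: row stride (in quantized blocks) passed to the MMQ kernels.
// For a dense buffer the stride is ne00 / blck_size; nb01 / type_size reflects
// the layout of the original tensor and differs once the data has been repacked.
// All values below are illustrative.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne00      = 4096;   // elements per row
    const int64_t blck_size = 32;     // Q8_0 block size (elements per block)
    const int64_t type_size = 34;     // bytes per Q8_0 block
    const int64_t nb01      = 8704;   // hypothetical byte stride of a non-contiguous src0 view

    const int64_t stride_old = nb01 / type_size;  // 256: follows the source layout
    const int64_t stride_new = ne00 / blck_size;  // 128: matches the dense temporary copy

    printf("old stride00 = %lld blocks, new stride00 = %lld blocks\n",
           (long long) stride_old, (long long) stride_new);
    return 0;
}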

ggml/src/ggml-cuda/quantize.cu

Lines changed: 2 additions & 5 deletions

@@ -84,7 +84,8 @@ static __global__ void quantize_mmq_q8_1(
         }
     }
 
-    const float d_inv = 127.0f / amax;
+    const float d = amax/127.f;
+    const float d_inv = d > 0 ? 1/d : 0.f;
     char4 q;
     q.x = roundf(xi.x*d_inv);
     q.y = roundf(xi.y*d_inv);
@@ -106,8 +107,6 @@ static __global__ void quantize_mmq_q8_1(
         return;
     }
 
-    const float d = 1.0f / d_inv;
-
     y[ib].d2s6[iqs/64] = d;
 
     return;
@@ -117,8 +116,6 @@ static __global__ void quantize_mmq_q8_1(
         return;
     }
 
-    const float d = 1.0f / d_inv;
-
     if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
         y[ib].ds4[iqs/32] = make_half2(d, sum);
     } else {
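
This is the part of the commit that directly removes the NaNs: with the old code, a group whose values are all zero (for instance one coming from freshly cleared padding) has amax == 0, so d_inv = 127.0f/amax is inf and roundf(0.0f*inf) is NaN. Guarding d_inv alone would presumably not have been enough either, because the later d = 1.0f/d_inv would turn a guarded zero back into an infinite scale, which is why those two lines are dropped and d is computed first. A stand-alone sketch of the guarded computation for one 4-element group, not the actual kernel:

// Sketch of the guarded q8_1-style scale for one group of 4 floats.
// If all inputs are zero: old code -> d_inv = inf, quants become NaN/garbage;
// new code -> d = 0, d_inv = 0, so quants and stored scale are all zero.
#include <cmath>
#include <cstdio>

struct Q8x4 { signed char q[4]; float d; };

static Q8x4 quantize4(const float * xi) {
    float amax = 0.0f;
    for (int i = 0; i < 4; ++i) amax = std::fmax(amax, std::fabs(xi[i]));

    const float d     = amax/127.f;             // scale, may legitimately be 0
    const float d_inv = d > 0 ? 1/d : 0.f;      // guarded inverse, as in the patch

    Q8x4 out;
    for (int i = 0; i < 4; ++i) out.q[i] = (signed char) std::round(xi[i]*d_inv);
    out.d = d;                                  // store d directly, no 1.0f/d_inv re-inversion
    return out;
}

int main() {
    const float zeros[4] = {0.0f, 0.0f, 0.0f, 0.0f};
    const float vals [4] = {0.5f, -1.0f, 0.25f, 2.0f};

    const Q8x4 a = quantize4(zeros); // with the old code this group produced inf/NaN
    const Q8x4 b = quantize4(vals);

    printf("zeros: d = %g, q = %d %d %d %d\n", a.d, a.q[0], a.q[1], a.q[2], a.q[3]);
    printf("vals : d = %g, q = %d %d %d %d\n", b.d, b.q[0], b.q[1], b.q[2], b.q[3]);
    return 0;
}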
