Skip to content

Commit 4ff70ea

Browse files
committed
Reapply commit "CUDA: noncont MMVQ + batched bs1 MUL_MAT_ID (ggml-org#13014)" quantize
1 parent c7c4d03 commit 4ff70ea

File tree

3 files changed

+57
-35
lines changed

3 files changed

+57
-35
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,6 +1630,11 @@ static void ggml_cuda_op_mul_mat(
16301630
const int64_t ne0 = dst->ne[0];
16311631
const int64_t ne1 = dst->ne[1];
16321632

1633+
// const int64_t nb10 = src1->nb[0];
1634+
const int64_t nb11 = src1->nb[1];
1635+
const int64_t nb12 = src1->nb[2];
1636+
const int64_t nb13 = src1->nb[3];
1637+
16331638
const int64_t nb2 = dst->nb[2];
16341639
const int64_t nb3 = dst->nb[3];
16351640

@@ -1764,7 +1769,10 @@ static void ggml_cuda_op_mul_mat(
17641769
dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
17651770

17661771
if (src1_on_device && src1_is_contiguous) {
1767-
quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
1772+
quantize_src1(
1773+
dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10,
1774+
nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
1775+
src1_padded_col_size, ne11, ne12, ne13, stream);
17681776
CUDA_CHECK(cudaGetLastError());
17691777
}
17701778
}
@@ -1862,7 +1870,9 @@ static void ggml_cuda_op_mul_mat(
18621870
}
18631871

18641872
if (quantize_src1 && !src1_is_contiguous) {
1865-
quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
1873+
quantize_src1(
1874+
src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
1875+
src1_padded_col_size, src1_ncols, 1, 1, stream);
18661876
CUDA_CHECK(cudaGetLastError());
18671877
}
18681878

ggml/src/ggml-cuda/quantize.cu

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,40 @@
11
#include "quantize.cuh"
22
#include <cstdint>
33

4-
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
5-
const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
4+
static __global__ void quantize_q8_1(
5+
const float * __restrict__ x, void * __restrict__ vy,
6+
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
7+
const int64_t ne0, const int ne1, const int ne2) {
8+
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
69

7-
if (ix0 >= kx0_padded) {
10+
if (i0 >= ne0) {
811
return;
912
}
1013

11-
const int64_t ix1 = blockIdx.y;
14+
const int64_t i1 = blockIdx.y;
15+
const int64_t i2 = blockIdx.z % ne2;
16+
const int64_t i3 = blockIdx.z / ne2;
17+
18+
const int64_t & i00 = i0;
19+
const int64_t & i01 = i1;
20+
const int64_t & i02 = i2;
21+
const int64_t & i03 = i3;
1222

13-
const int64_t i_padded = ix1*kx0_padded + ix0;
23+
const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
1424

1525
block_q8_1 * y = (block_q8_1 *) vy;
1626

17-
const int64_t ib = i_padded / QK8_1; // block index
18-
const int64_t iqs = i_padded % QK8_1; // quant index
27+
const int64_t ib = i_cont / QK8_1; // block index
28+
const int64_t iqs = i_cont % QK8_1; // quant index
1929

20-
const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
30+
const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;
2131
float amax = fabsf(xi);
2232
float sum = xi;
2333

2434
amax = warp_reduce_max(amax);
25-
sum = warp_reduce_sum(sum);
35+
sum = warp_reduce_sum(sum);
2636

27-
const float d = amax / 127;
37+
const float d = amax / 127;
2838
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
2939

3040
y[ib].qs[iqs] = q;
@@ -124,43 +134,45 @@ static __global__ void quantize_mmq_q8_1(
124134
}
125135

126136
void quantize_row_q8_1_cuda(
127-
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
128-
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
137+
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
138+
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
129139

130-
GGML_ASSERT(kx0_padded % QK8_1 == 0);
140+
GGML_ASSERT(ne0 % QK8_1 == 0);
131141

132-
const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
133-
const dim3 num_blocks(block_num_x, kx1*channels, 1);
142+
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
143+
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
134144
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
135-
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
136-
137-
GGML_UNUSED(type_x);
145+
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
146+
GGML_UNUSED(type_src0);
138147
}
139148

140149
void quantize_mmq_q8_1_cuda(
141-
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
142-
const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
150+
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
151+
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
143152

144-
GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
153+
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
145154

146-
const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
147-
const dim3 num_blocks(block_num_x, kx1, channels);
155+
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
156+
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
148157
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
149-
switch (mmq_get_q8_1_ds_layout(type_x)) {
158+
switch (mmq_get_q8_1_ds_layout(type_src0)) {
150159
case MMQ_Q8_1_DS_LAYOUT_D4:
151160
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
152-
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
161+
<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
153162
break;
154163
case MMQ_Q8_1_DS_LAYOUT_DS4:
155164
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
156-
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
165+
<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
157166
break;
158167
case MMQ_Q8_1_DS_LAYOUT_D2S6:
159168
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
160-
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
169+
<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
161170
break;
162171
default:
163172
GGML_ABORT("fatal error");
164173
break;
165174
}
175+
GGML_UNUSED(s01);
176+
GGML_UNUSED(s02);
177+
GGML_UNUSED(s03);
166178
}

ggml/src/ggml-cuda/quantize.cuh

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk
1212
static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
1313

1414
typedef void (*quantize_cuda_t)(
15-
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
16-
const ggml_type type_x, cudaStream_t stream);
15+
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
16+
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
1717

1818
void quantize_row_q8_1_cuda(
19-
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
20-
const ggml_type type_x, cudaStream_t stream);
19+
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
20+
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
2121

2222
void quantize_mmq_q8_1_cuda(
23-
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
24-
const ggml_type type_x, cudaStream_t stream);
23+
const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
24+
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);

0 commit comments

Comments (0)