Commit becc481

ggml-cpu: handle 3d tensors in repack mat_mul (#17241)
* ggml-cpu: handle 3d tensors in repack mul_mat
* Removed unnecessary branch, removed need for <algorithm>
* Fixed dst_ptr pointer in chunk + clang_format
* GGML_ASSERT to check wdata within bounds
* Accidental ggml.h inclusion
* Improved GGML_ASSERT on wdata boundaries
* Address performance regression in Qwen and llama.cpp due to chunking
1 parent c4abcb2 commit becc481

File tree

1 file changed (+95, -42 lines)

ggml/src/ggml-cpu/repack.cpp

Lines changed: 95 additions & 42 deletions
@@ -1600,29 +1600,52 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         return false;
     }

-    void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
+    void forward_mul_mat_one_chunk(ggml_compute_params * params,
+                                   ggml_tensor *         op,
+                                   int64_t               src0_start,
+                                   int64_t               src0_end,
+                                   int64_t               src1_start,
+                                   int64_t               src1_end) {
         const ggml_tensor * src0 = op->src[0];
         const ggml_tensor * src1 = op->src[1];
         ggml_tensor *       dst  = op;

         GGML_TENSOR_BINARY_OP_LOCALS

-        const void * src1_wdata      = params->wdata;
         const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);

+        GGML_ASSERT(ne03 == 1 && ne13 == 1);
+        GGML_ASSERT(ne12 % ne02 == 0);
+        const int64_t r2 = ne12 / ne02;
+
+        const int64_t i12 = src1_start / ne1;
+        const int64_t i11 = src1_start - i12 * ne1;
+
+        // Determine batch index
+        const int64_t i02 = i12 / r2;
+
+        const int64_t i1 = i11;
+        const int64_t i2 = i12;
+
+        const char * src0_ptr = (const char *) src0->data + i02 * nb02;
+        const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
+        char *       dst_ptr  = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
+
+        const int64_t nrows = src1_end - src1_start;
+        const int64_t ncols = src0_end - src0_start;
+
+        GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
+
         // If there are more than three rows in src1, use gemm; otherwise, use gemv.
-        if (ne11 > 3) {
-            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                (float *) ((char *) dst->data) + src0_start, ne01,
-                (const char *) src0->data + src0_start * nb01,
-                (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        if (nrows > 3) {
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
+                                                             src0_ptr + src0_start * nb01, src1_ptr,
+                                                             nrows - (nrows % 4), ncols);
         }
-        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
-            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
-                (const char *) src0->data + src0_start * nb01,
-                (const char *) src1_wdata + (src1_col_stride * iter), 1,
-                src0_end - src0_start);
+        for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
+                                                             ne01, src0_ptr + src0_start * nb01,
+                                                             src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
        }
     }
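The hunk above turns the chunk's flat src1 row offset into a plane index, a broadcast batch index, and byte offsets into src0, the quantized wdata buffer, and dst. Below is a minimal standalone sketch of that index arithmetic, not code from the commit: the shapes (ne02, ne12, ne11, ne1) and the strides (src1_col_stride, nb02) are invented example values, and only the mapping formulas mirror the diff.

    // Sketch of the per-chunk index arithmetic (illustrative values only).
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical 3D case: src0 has ne02 = 2 weight planes, src1 has ne12 = 6
        // activation planes, so each weight plane is broadcast over r2 = 3 planes.
        const int64_t ne02 = 2, ne12 = 6, ne11 = 4, ne1 = 4;  // ne1 == ne11 here
        const int64_t r2 = ne12 / ne02;                       // broadcast ratio (ne12 % ne02 == 0)

        const size_t src1_col_stride = 544;                   // stand-in for ggml_row_size(PARAM_TYPE, ne10)
        const size_t nb02            = 1 << 20;               // made-up byte stride between src0 planes

        // Each chunk covers a full src1 plane, so src1_start = ith1 * ne11.
        for (int64_t src1_start = 0; src1_start < ne12 * ne11; src1_start += ne11) {
            const int64_t i12 = src1_start / ne1;             // which src1 plane
            const int64_t i11 = src1_start - i12 * ne1;       // row inside the plane (0 for full planes)
            const int64_t i02 = i12 / r2;                     // src0 plane feeding this src1 plane (broadcast)

            const size_t src0_off = (size_t) i02 * nb02;
            const size_t src1_off = (size_t) (i11 + i12 * ne11) * src1_col_stride;

            printf("src1_start=%2lld -> i12=%lld i11=%lld i02=%lld  src0_off=%zu src1_off=%zu\n",
                   (long long) src1_start, (long long) i12, (long long) i11, (long long) i02,
                   src0_off, src1_off);
        }
        return 0;
    }

With these example shapes the six src1 planes map to src0 planes 0,0,0,1,1,1, which is the broadcast behaviour the new GGML_ASSERT(ne12 % ne02 == 0) guards.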

@@ -1647,54 +1670,77 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         GGML_ASSERT(nb1 <= nb2);
         GGML_ASSERT(nb2 <= nb3);

+        // TODO: General batched mul mat for 4D tensors
+        // Currently only supports 3D tensors
+        GGML_ASSERT(ne03 == 1);
+        GGML_ASSERT(ne13 == 1);
+        GGML_ASSERT(ne3 == 1);
+
         GGML_ASSERT(src1->type == GGML_TYPE_F32);

         GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
         // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);

         char *       wdata = static_cast<char *>(params->wdata);
         const size_t nbw1  = ggml_row_size(PARAM_TYPE, ne10);
+        const size_t nbw2  = nbw1 * ne11;

-        assert(params->wsize >= nbw1 * ne11);
+        assert(params->wsize >= nbw2 * ne12);

         const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;

-        int64_t i11_processed = 0;
-        for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
-            ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
-        }
+        // INFO: Quantization is done in planes to avoid extra complexity in chunking.
+        // Flattening dimensions not multiple of INTER_SIZE would require extra handling depending on how
+        // the planes are broadcast.
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            char * data_ptr  = (char *) src1->data + i12 * nb12;
+            char * wdata_ptr = wdata + i12 * nbw2;

-        i11_processed = ne11 - ne11 % 4;
-        for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
-            from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
+            for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+                ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
+                                                            (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
+            }
+
+            const int64_t i11_processed = ne11 - ne11 % 4;
+            for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+                from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
+            }
         }

         // disable for NUMA
         const bool disable_chunking = ggml_is_numa();

         // 4x chunks per thread
-        int64_t nr = ggml_nrows(op->src[0]);
-        int nth_scaled = nth * 4;
-        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
-        int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+        const int64_t nr0 = ggml_nrows(op->src[0]);
+
+        int     nth_scaled  = nth * 4;
+        int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
+        int64_t nchunk0     = (nr0 + chunk_size0 - 1) / chunk_size0;
+
+        // src1 is chunked only by full planes.
+        // When we flatten we need to address dimensions not multiple of the q8 INTER_SIZE
+        // to route them thorugh GEMV.
+        // nchunk1 = ne12 also avoids messing the chunking for models with no 3d tensors
+        // to avoid affecting their performance
+        int64_t nchunk1 = ne12;

         // Ensure minimum chunk size to avoid alignment issues with high thread counts
         // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
         const int64_t min_chunk_size = NB_COLS;
-        if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
-            nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
+        if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
+            nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
         }

-        if (nth == 1 || nchunk < nth || disable_chunking) {
-            nchunk = nth;
+        if (nth == 1 || nchunk0 < nth || disable_chunking) {
+            nchunk0 = nth;
         }

+        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+
         // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
         // This prevents creating too many tiny chunks that could overlap after alignment
-        const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
-        if (nchunk > max_nchunk) {
-            nchunk = max_nchunk;
-        }
+        const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
+        nchunk0 = MIN(nchunk0, max_nchunk);

         if (ith == 0) {
             // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
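This hunk quantizes src1 plane by plane into wdata (row stride nbw1, plane stride nbw2) and then sizes the work grid: nchunk0 chunks along the src0 rows and nchunk1 = ne12 chunks, one per src1 plane. The following standalone sketch reproduces only the chunk-count arithmetic, not the ggml code; nr0, ne12, nth, and NB_COLS are invented example values, and the NUMA disable_chunking path is omitted.

    // Sketch of the chunk-count derivation (illustrative values only).
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t nr0     = 4096;  // stand-in for ggml_nrows(src0): rows to split across threads
        const int64_t ne12    = 6;     // src1 planes -> nchunk1, one chunk per plane
        const int     nth     = 8;     // number of threads
        const int64_t NB_COLS = 8;     // minimum alignment of a src0 chunk boundary

        // roughly 4 chunks per thread over the src0 rows
        const int     nth_scaled  = nth * 4;
        const int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
        int64_t       nchunk0     = (nr0 + chunk_size0 - 1) / chunk_size0;

        // keep chunks at least NB_COLS rows so alignment cannot make them overlap
        const int64_t min_chunk_size = NB_COLS;
        if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
            nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
        }
        if (nth == 1 || nchunk0 < nth) {  // disable_chunking (NUMA) not modeled here
            nchunk0 = nth;
        }
        const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
        nchunk0 = std::min(nchunk0, max_nchunk);

        const int64_t dr0     = (nr0 + nchunk0 - 1) / nchunk0;  // src0 rows per chunk
        const int64_t nchunk1 = ne12;                            // src1 chunked by full planes

        printf("nchunk0=%lld nchunk1=%lld dr0=%lld total chunks=%lld\n",
               (long long) nchunk0, (long long) nchunk1, (long long) dr0,
               (long long) (nchunk0 * nchunk1));
        return 0;
    }

For a 2D case ne12 is 1, so nchunk1 = 1 and the grid degenerates to the previous row-only chunking, which is how the patch avoids perturbing performance on models without 3D tensors.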
@@ -1706,23 +1752,30 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         // The first chunk comes from our thread_id, the rest will get auto-assigned.
         int current_chunk = ith;

-        while (current_chunk < nchunk) {
-            int64_t src0_start = (current_chunk * ne01) / nchunk;
-            int64_t src0_end   = ((current_chunk + 1) * ne01) / nchunk;
+        while (current_chunk < nchunk0 * nchunk1) {
+            const int64_t ith0 = current_chunk % nchunk0;
+            const int64_t ith1 = current_chunk / nchunk0;
+
+            int64_t src0_start = dr0 * ith0;
+            int64_t src0_end   = MIN(src0_start + dr0, nr0);
+
+            // full-plane range for src1
+            int64_t src1_start = ith1 * ne11;
+            int64_t src1_end   = (ith1 + 1) * ne11;

             // Align boundaries to NB_COLS - round up to ensure all data is included
             // The chunk size limiting above ensures chunks are large enough to prevent overlaps
             src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
-            src0_end   = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
-            if (src0_end > ne01) {
-                src0_end = ne01;
-            }
+            src0_end   = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+            src0_end   = MIN(src0_end, ne01);

+            // Make sure current plane is the last one before exiting
             if (src0_start >= src0_end) {
-                break;
+                current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+                continue;
             }

-            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
+            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);

             current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
         }
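The scheduling loop above flattens the (src0 rows x src1 planes) grid into a single chunk counter shared through the threadpool. A standalone sketch of that decomposition follows; it is not the commit's code, and ne01, ne11, nchunk0, nchunk1, and NB_COLS are made-up example values chosen so the alignment behaviour is visible.

    // Sketch: flat chunk index -> (src0 row range, src1 plane), with NB_COLS alignment.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne01 = 100, nr0 = 100;    // src0 rows (2D weight matrix)
        const int64_t ne11 = 4,   nchunk1 = 3;  // src1 rows per plane, number of planes
        const int64_t nchunk0 = 4;              // chunks along the src0 rows
        const int64_t NB_COLS = 8;              // chunk boundaries are rounded up to this
        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;

        for (int64_t current_chunk = 0; current_chunk < nchunk0 * nchunk1; current_chunk++) {
            const int64_t ith0 = current_chunk % nchunk0;  // position along src0 rows
            const int64_t ith1 = current_chunk / nchunk0;  // which src1 plane

            int64_t src0_start = dr0 * ith0;
            int64_t src0_end   = std::min(src0_start + dr0, nr0);

            // round both boundaries up to a multiple of NB_COLS, then clamp to ne01
            src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
            src0_end   = (src0_end   % NB_COLS) ? src0_end   + NB_COLS - (src0_end   % NB_COLS) : src0_end;
            src0_end   = std::min(src0_end, ne01);

            const int64_t src1_start = ith1 * ne11;        // full src1 plane
            const int64_t src1_end   = (ith1 + 1) * ne11;

            if (src0_start >= src0_end) {
                // empty after alignment: skip this chunk (the real loop fetches the next index)
                continue;
            }
            printf("chunk %2lld: src0 rows [%3lld, %3lld)  src1 rows [%2lld, %2lld)\n",
                   (long long) current_chunk, (long long) src0_start, (long long) src0_end,
                   (long long) src1_start, (long long) src1_end);
        }
        return 0;
    }

Skipping an empty range with continue instead of break, as the patch does, matters because later chunk indices can still belong to other src1 planes; breaking out early would leave those planes uncomputed.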
