Commit 517b717
cpu: introduce chunking for repack matmuls and enable matmul-id chunking on ARM64 (#16833)
Very similar implementation to the flash-attention chunking, with similar benefits: the output rows are split into more chunks than threads, and threads that finish early claim the next chunk from a shared counter, which balances load across cores.
Parent commit: 835e918

2 files changed: +57 -26 lines
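For readers who have not seen the flash-attention version, the pattern is easiest to follow in isolation: split the output rows into more chunks than threads, give thread i chunk i to start, and let threads that finish early claim further chunks from a shared counter. Below is a minimal standalone C++ sketch of that idea; the std::atomic counter stands in for ggml's threadpool chunk counter (ggml_threadpool_chunk_set / ggml_threadpool_chunk_add), and process_rows and the sizes are hypothetical.

#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

// Stand-in for the per-chunk gemm/gemv work in the commit.
static void process_rows(int ith, int64_t start, int64_t end) {
    std::printf("thread %d: rows [%lld, %lld)\n", ith, (long long) start, (long long) end);
}

int main() {
    const int     nth  = 4;     // worker threads
    const int64_t ne01 = 4096;  // output rows to cover

    // 4x chunks per thread, as in the commit.
    const int64_t chunk_size = (ne01 + nth * 4 - 1) / (nth * 4);
    const int64_t nchunk     = (ne01 + chunk_size - 1) / chunk_size;

    // Chunks 0..nth-1 are implicitly claimed by thread IDs, so the
    // shared counter starts at nth (cf. ggml_threadpool_chunk_set).
    std::atomic<int> next_chunk{nth};

    std::vector<std::thread> workers;
    for (int ith = 0; ith < nth; ith++) {
        workers.emplace_back([&, ith] {
            int current_chunk = ith;
            while (current_chunk < nchunk) {
                const int64_t start = (current_chunk * ne01) / nchunk;
                const int64_t end   = ((current_chunk + 1) * ne01) / nchunk;
                process_rows(ith, start, end);
                // fetch_add returns the previous value, i.e. the chunk we
                // now own (cf. ggml_threadpool_chunk_add).
                current_chunk = next_chunk.fetch_add(1);
            }
        });
    }
    for (auto & w : workers) {
        w.join();
    }
    return 0;
}

The counter starting at nth is the same trick the repack.cpp change below uses: the first nth chunks need no coordination at all.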

ggml/src/ggml-cpu/ggml-cpu.c (0 additions, 5 deletions)

@@ -1613,13 +1613,8 @@ static void ggml_compute_forward_mul_mat_id(
         chunk_size = 64;
     }
 
-#if defined(__aarch64__)
-    // disable for ARM
-    const bool disable_chunking = true;
-#else
     // disable for NUMA
     const bool disable_chunking = ggml_is_numa();
-#endif // defined(__aarch64__)
 
     int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
     int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
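With the #ifdef removed, ARM64 follows the same policy as every other architecture: matmul-id chunking stays on unless the system is NUMA. For intuition about the ceiling divisions just after the removed lines, here is a tiny worked example (the sizes are hypothetical):

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical sizes: nr0/nr1 are the row ranges being split in
    // ggml_compute_forward_mul_mat_id, chunk_size as set above.
    const int64_t nr0 = 4096, nr1 = 8, chunk_size = 64;
    const int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;  // 64
    const int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;  // 1
    std::printf("%lld x %lld chunks\n", (long long) nchunk0, (long long) nchunk1);
}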

ggml/src/ggml-cpu/repack.cpp (57 additions, 21 deletions)

@@ -1600,6 +1600,32 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
         return false;
     }
 
+    void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
+        const ggml_tensor * src0 = op->src[0];
+        const ggml_tensor * src1 = op->src[1];
+        ggml_tensor * dst = op;
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const void * src1_wdata = params->wdata;
+        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
+
+        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+        if (ne11 > 3) {
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                (float *) ((char *) dst->data) + src0_start, ne01,
+                (const char *) src0->data + src0_start * nb01,
+                (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        }
+        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                (const char *) src0->data + src0_start * nb01,
+                (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                src0_end - src0_start);
+        }
+    }
+
     void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
         const ggml_tensor * src0 = op->src[0];
         const ggml_tensor * src1 = op->src[1];
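One detail worth calling out in the new helper: gemm consumes src1 rows four at a time (the first ne11 - ne11 % 4 of them), and the trailing gemv loop picks up the remainder one row at a time. A quick check of those bounds (ne11 = 11 is a hypothetical value):

#include <cstdio>

int main() {
    const int ne11      = 11;               // hypothetical src1 row count
    const int gemm_rows = ne11 - ne11 % 4;  // 8: largest multiple of 4 <= ne11
    std::printf("gemm covers src1 rows [0, %d)\n", gemm_rows);
    for (int iter = gemm_rows; iter < ne11; iter++) {
        std::printf("gemv handles src1 row %d\n", iter);  // rows 8, 9, 10
    }
}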
@@ -1643,31 +1669,41 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
             from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
         }
 
-        ggml_barrier(params->threadpool);
+        // disable for NUMA
+        const bool disable_chunking = ggml_is_numa();
 
-        const void * src1_wdata = params->wdata;
-        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
-        int64_t src0_start = (ith * ne01) / nth;
-        int64_t src0_end = ((ith + 1) * ne01) / nth;
-        src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
-        src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
-        if (src0_start >= src0_end) {
-            return;
+        // 4x chunks per thread
+        int64_t nr = ggml_nrows(op->src[0]);
+        int nth_scaled = nth * 4;
+        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+        int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+        if (nth == 1 || nchunk < nth || disable_chunking) {
+            nchunk = nth;
         }
 
-        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
-        if (ne11 > 3) {
-            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                (float *) ((char *) dst->data) + src0_start, ne01,
-                (const char *) src0->data + src0_start * nb01,
-                (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        if (ith == 0) {
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+            ggml_threadpool_chunk_set(params->threadpool, nth);
         }
-        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
-            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
-                (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
-                (const char *) src0->data + src0_start * nb01,
-                (const char *) src1_wdata + (src1_col_stride * iter), 1,
-                src0_end - src0_start);
+
+        ggml_barrier(params->threadpool);
+
+        // The first chunk comes from our thread_id, the rest will get auto-assigned.
+        int current_chunk = ith;
+
+        while (current_chunk < nchunk) {
+            int64_t src0_start = (current_chunk * ne01) / nchunk;
+            int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
+            src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+            src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+            if (src0_start >= src0_end) {
+                break;
+            }
+
+            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
+
+            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
         }
     }
 
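Note the boundary rounding inside the while loop: repacked weights interleave NB_COLS source rows into each block, so chunk edges are rounded up to NB_COLS multiples to keep blocks intact (ne01 of a repacked tensor is already a multiple of NB_COLS, so the last chunk still ends exactly at ne01). A small sketch of that rounding, with illustrative values:

#include <cstdint>
#include <cstdio>

// Round a row index up to the next multiple of nb_cols, mirroring the
// chunk-boundary alignment in the repack matmul above.
static int64_t round_up(int64_t v, int64_t nb_cols) {
    return (v % nb_cols) ? v + nb_cols - (v % nb_cols) : v;
}

int main() {
    const int64_t NB_COLS = 4;   // illustrative; really a template parameter
    const int64_t ne01    = 16;  // illustrative row count (multiple of NB_COLS)
    const int64_t nchunk  = 3;
    // Raw boundaries 0, 5, 10, 16 become aligned boundaries 0, 8, 12, 16.
    for (int64_t chunk = 0; chunk < nchunk; chunk++) {
        int64_t start = round_up((chunk * ne01) / nchunk, NB_COLS);
        int64_t end   = round_up(((chunk + 1) * ne01) / nchunk, NB_COLS);
        std::printf("chunk %lld: rows [%lld, %lld)\n",
                    (long long) chunk, (long long) start, (long long) end);
    }
}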
