@@ -1600,6 +1600,32 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
16001600 return false ;
16011601 }
16021602
// Runs the matrix multiplication for one slice of src0 rows, covering the
// half-open row range [src0_start, src0_end) of op->src[0].
//
// params     - compute params; params->wdata is read as the src1 operand buffer
//              (presumably src1 already converted to PARAM_TYPE by the caller -- TODO confirm).
// op         - the destination tensor; its src[0] and src[1] are the two operands.
// src0_start - first src0 row handled by this call; also used as the float offset into each dst row.
// src0_end   - one past the last src0 row handled; src0_end - src0_start is passed to gemm/gemv.
1603+ void forward_mul_mat_one_chunk (ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
1604+ const ggml_tensor * src0 = op->src [0 ];
1605+ const ggml_tensor * src1 = op->src [1 ];
1606+ ggml_tensor * dst = op;
1607+
      // NOTE(review): this macro is assumed to declare the ne00/ne01/ne10/ne11/nb01/nb1
      // locals used below from src0, src1 and dst -- confirm against ggml.h.
1608+ GGML_TENSOR_BINARY_OP_LOCALS
1609+
      // Byte distance between consecutive src1 rows inside the wdata buffer.
1610+ const void * src1_wdata = params->wdata ;
1611+ const size_t src1_col_stride = ggml_row_size (PARAM_TYPE, ne10);
1612+
1613+ // If there are more than three rows in src1, use gemm; otherwise, use gemv.
      // gemm is given the first (ne11 - ne11 % 4) src1 rows, i.e. the largest multiple of four.
1614+ if (ne11 > 3 ) {
1615+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1616+ (float *) ((char *) dst->data ) + src0_start, ne01,
1617+ (const char *) src0->data + src0_start * nb01,
1618+ (const char *) src1_wdata, ne11 - ne11 % 4 , src0_end - src0_start);
1619+ }
      // Leftover src1 rows (ne11 % 4 of them) go through gemv one at a time. When ne11 <= 3
      // the gemm branch is skipped and the loop starts at 0, so every row is handled here.
1620+ for (int iter = ne11 - ne11 % 4 ; iter < ne11; iter++) {
1621+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1622+ (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
1623+ (const char *) src0->data + src0_start * nb01,
1624+ (const char *) src1_wdata + (src1_col_stride * iter), 1 ,
1625+ src0_end - src0_start);
1626+ }
1627+ }
1628+
16031629 void forward_mul_mat (ggml_compute_params * params, ggml_tensor * op) {
16041630 const ggml_tensor * src0 = op->src [0 ];
16051631 const ggml_tensor * src1 = op->src [1 ];
@@ -1643,31 +1669,41 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
16431669 from_float ((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
16441670 }
16451671
1646- ggml_barrier (params->threadpool );
1672+ // disable for NUMA
1673+ const bool disable_chunking = ggml_is_numa ();
16471674
1648- const void * src1_wdata = params-> wdata ;
1649- const size_t src1_col_stride = ggml_row_size (PARAM_TYPE, ne10 );
1650- int64_t src0_start = (ith * ne01) / nth ;
1651- int64_t src0_end = ((ith + 1 ) * ne01 ) / nth ;
1652- src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start ;
1653- src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
1654- if (src0_start >= src0_end ) {
1655- return ;
1675+ // 4x chunks per thread
1676+ int64_t nr = ggml_nrows (op-> src [ 0 ] );
1677+ int nth_scaled = nth * 4 ;
1678+ int64_t chunk_size = (nr + nth_scaled - 1 ) / nth_scaled ;
1679+ int64_t nchunk = (nr + chunk_size - 1 ) / chunk_size ;
1680+
1681+ if (nth == 1 || nchunk < nth || disable_chunking ) {
1682+ nchunk = nth ;
16561683 }
16571684
1658- // If there are more than three rows in src1, use gemm; otherwise, use gemv.
1659- if (ne11 > 3 ) {
1660- gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1661- (float *) ((char *) dst->data ) + src0_start, ne01,
1662- (const char *) src0->data + src0_start * nb01,
1663- (const char *) src1_wdata, ne11 - ne11 % 4 , src0_end - src0_start);
1685+ if (ith == 0 ) {
1686+ // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
1687+ ggml_threadpool_chunk_set (params->threadpool , nth);
16641688 }
1665- for (int iter = ne11 - ne11 % 4 ; iter < ne11; iter++) {
1666- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
1667- (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
1668- (const char *) src0->data + src0_start * nb01,
1669- (const char *) src1_wdata + (src1_col_stride * iter), 1 ,
1670- src0_end - src0_start);
1689+
1690+ ggml_barrier (params->threadpool );
1691+
1692+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
1693+ int current_chunk = ith;
1694+
1695+ while (current_chunk < nchunk) {
1696+ int64_t src0_start = (current_chunk * ne01) / nchunk;
1697+ int64_t src0_end = ((current_chunk + 1 ) * ne01) / nchunk;
1698+ src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
1699+ src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
1700+ if (src0_start >= src0_end) {
1701+ break ;
1702+ }
1703+
1704+ forward_mul_mat_one_chunk (params, dst, src0_start, src0_end);
1705+
1706+ current_chunk = ggml_threadpool_chunk_add (params->threadpool , 1 );
16711707 }
16721708 }
16731709
0 commit comments