@@ -1600,52 +1600,29 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         return false;
     }
 
-    void forward_mul_mat_one_chunk(ggml_compute_params * params,
-                                   ggml_tensor *         op,
-                                   int64_t               src0_start,
-                                   int64_t               src0_end,
-                                   int64_t               src1_start,
-                                   int64_t               src1_end) {
+    void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
         const ggml_tensor * src0 = op->src[0];
         const ggml_tensor * src1 = op->src[1];
         ggml_tensor *       dst  = op;
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
+        const void * src1_wdata      = params->wdata;
         const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
 
-        GGML_ASSERT(ne03 == 1 && ne13 == 1);
-        GGML_ASSERT(ne12 % ne02 == 0);
-        const int64_t r2 = ne12 / ne02;
-
-        const int64_t i12 = src1_start / ne1;
-        const int64_t i11 = src1_start - i12 * ne1;
-
-        // Determine batch index
-        const int64_t i02 = i12 / r2;
-
-        const int64_t i1 = i11;
-        const int64_t i2 = i12;
-
-        const char * src0_ptr = (const char *) src0->data + i02 * nb02;
-        const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
-        char *       dst_ptr  = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
-
-        const int64_t nrows = src1_end - src1_start;
-        const int64_t ncols = src0_end - src0_start;
-
-        GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
-
         // If there are more than three rows in src1, use gemm; otherwise, use gemv.
-        if (nrows > 3) {
-            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
-                                                             src0_ptr + src0_start * nb01, src1_ptr,
-                                                             nrows - (nrows % 4), ncols);
+        if (ne11 > 3) {
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                    (float *) ((char *) dst->data) + src0_start, ne01,
+                    (const char *) src0->data + src0_start * nb01,
+                    (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
         }
-        for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
-            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
-                                                             ne01, src0_ptr + src0_start * nb01,
-                                                             src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
+        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+                    (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                    (const char *) src0->data + src0_start * nb01,
+                    (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                    src0_end - src0_start);
         }
     }
 
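To make the dispatch on the "+" side of this hunk easier to follow, here is a minimal, self-contained sketch of the row split it performs: complete groups of four src1 rows go through a single gemm call, and the remaining ne11 % 4 rows (or all rows when ne11 <= 3) are handled one at a time by gemv. This is not part of the commit; process_gemm and process_gemv are hypothetical stand-ins for the templated ggml kernels.

// Sketch only: illustrates the gemm/gemv row split, not the real kernels.
#include <cstdint>
#include <cstdio>

static void process_gemm(int64_t first_row, int64_t nrows) {
    std::printf("gemm: rows [%lld, %lld)\n", (long long) first_row, (long long) (first_row + nrows));
}

static void process_gemv(int64_t row) {
    std::printf("gemv: row %lld\n", (long long) row);
}

static void dispatch_rows(int64_t ne11) {
    // Rows that form complete groups of 4 are handled by one gemm call ...
    if (ne11 > 3) {
        process_gemm(0, ne11 - ne11 % 4);
    }
    // ... and the trailing ne11 % 4 rows (or all rows when ne11 <= 3) by gemv.
    for (int64_t iter = ne11 - ne11 % 4; iter < ne11; iter++) {
        process_gemv(iter);
    }
}

int main() {
    dispatch_rows(10); // gemm covers rows [0, 8), gemv covers rows 8 and 9
    return 0;
}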
@@ -1670,73 +1647,54 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         GGML_ASSERT(nb1 <= nb2);
         GGML_ASSERT(nb2 <= nb3);
 
-        // TODO: General batched mul mat for 4D tensors
-        // Currently only supports 3D tensors
-        GGML_ASSERT(ne03 == 1);
-        GGML_ASSERT(ne13 == 1);
-        GGML_ASSERT(ne3 == 1);
-
         GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
         GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
         // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
 
         char * wdata = static_cast<char *>(params->wdata);
         const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
-        const size_t nbw2 = nbw1 * ne11;
 
-        assert(params->wsize >= nbw2 * ne12);
+        assert(params->wsize >= nbw1 * ne11);
 
         const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
 
-        for (int64_t i12 = 0; i12 < ne12; i12++) {
-            char * data_ptr  = (char *) src1->data + i12 * nb12;
-            char * wdata_ptr = wdata + i12 * nbw2;
-
-            for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
-                ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
-                                                            (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
-            }
+        int64_t i11_processed = 0;
+        for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+            ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
+        }
 
-            const int64_t i11_processed = ne11 - ne11 % 4;
-            for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
-                from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
-            }
+        i11_processed = ne11 - ne11 % 4;
+        for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+            from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
         }
 
         // disable for NUMA
         const bool disable_chunking = ggml_is_numa();
 
         // 4x chunks per thread
-        const int64_t nr0 = ggml_nrows(op->src[0]);
-        const int64_t nr1 = ne1 * ne2 * ne3;
-
-        int nth_scaled = nth * 4;
-        int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
-        // avoid too small chunks for narrow src1
-        int64_t chunk_size1 = MAX(16, (nr1 + nth - 1) / nth);
-        int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
-        int64_t nchunk1 = (nr1 + chunk_size1 - 1) / chunk_size1;
+        int64_t nr         = ggml_nrows(op->src[0]);
+        int     nth_scaled = nth * 4;
+        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+        int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
 
         // Ensure minimum chunk size to avoid alignment issues with high thread counts
         // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
         const int64_t min_chunk_size = NB_COLS;
-        if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
-            nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
+        if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
+            nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
         }
 
-        if (nth == 1 || nchunk0 * nchunk1 < nth || disable_chunking) {
-            nchunk0 = nr0 > nr1 ? nth : 1;
-            nchunk1 = nr0 > nr1 ? 1 : nth;
+        if (nth == 1 || nchunk < nth || disable_chunking) {
+            nchunk = nth;
         }
 
-        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
-        const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
-
         // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
         // This prevents creating too many tiny chunks that could overlap after alignment
-        const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
-        nchunk0 = MIN(nchunk0, max_nchunk);
+        const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
+        if (nchunk > max_nchunk) {
+            nchunk = max_nchunk;
+        }
 
         if (ith == 0) {
             // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
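As a reading aid for the chunk-count arithmetic on the "+" side of the hunk above, here is the same logic pulled out into a standalone function. compute_nchunk is a hypothetical name, nb_cols stands in for the NB_COLS template parameter, and disable_chunking mirrors the ggml_is_numa() check; the real code works directly on the tensor operands.

// Sketch only: the "+" side chunk-count computation as a pure function.
#include <cstdint>

static int64_t compute_nchunk(int64_t nr, int nth, int64_t nb_cols, bool disable_chunking) {
    // Assumes nr > 0 and nth > 0, as in the kernel.
    // Aim for roughly 4 chunks per thread.
    const int64_t nth_scaled = (int64_t) nth * 4;
    const int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
    int64_t       nchunk     = (nr + chunk_size - 1) / chunk_size;

    // Grow chunks that would be smaller than one NB_COLS-wide block, so that
    // aligning chunk boundaries later cannot make them overlap.
    const int64_t min_chunk_size = nb_cols;
    if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
        nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
    }

    // Fall back to one chunk per thread when chunking is not worthwhile.
    if (nth == 1 || nchunk < nth || disable_chunking) {
        nchunk = nth;
    }

    // Never create more chunks than ceil(nr / min_chunk_size).
    const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
    if (nchunk > max_nchunk) {
        nchunk = max_nchunk;
    }
    return nchunk;
}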
@@ -1748,29 +1706,23 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         // The first chunk comes from our thread_id, the rest will get auto-assigned.
         int current_chunk = ith;
 
-        while (current_chunk < nchunk0 * nchunk1) {
-            const int64_t ith0 = current_chunk % nchunk0;
-            const int64_t ith1 = current_chunk / nchunk0;
-
-            int64_t src0_start = dr0 * ith0;
-            int64_t src0_end   = MIN(src0_start + dr0, nr0);
-
-            int64_t src1_start = dr1 * ith1;
-            int64_t src1_end   = MIN(src1_start + dr1, nr1);
+        while (current_chunk < nchunk) {
+            int64_t src0_start = (current_chunk * ne01) / nchunk;
+            int64_t src0_end   = ((current_chunk + 1) * ne01) / nchunk;
 
             // Align boundaries to NB_COLS - round up to ensure all data is included
             // The chunk size limiting above ensures chunks are large enough to prevent overlaps
             src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
-            src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
-            src0_end = MIN(src0_end, ne01);
+            src0_end   = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+            if (src0_end > ne01) {
+                src0_end = ne01;
+            }
 
-            // Make sure current plane is the last one before exiting
             if (src0_start >= src0_end) {
-                current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
-                continue;
+                break;
             }
 
-            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
+            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
 
             current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
         }
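Finally, a small sketch of how each claimed chunk index is turned into an NB_COLS-aligned [src0_start, src0_end) row range on the "+" side of the last hunk. chunk_bounds is a hypothetical helper introduced only for illustration; in the commit this logic sits inline in the while loop, and an empty range is what triggers the break.

// Sketch only: per-chunk row-range computation with NB_COLS alignment.
#include <cstdint>
#include <utility>

static std::pair<int64_t, int64_t> chunk_bounds(int64_t chunk, int64_t nchunk, int64_t ne01, int64_t nb_cols) {
    // Proportional split of the ne01 src0 rows across nchunk chunks.
    int64_t src0_start = (chunk * ne01) / nchunk;
    int64_t src0_end   = ((chunk + 1) * ne01) / nchunk;

    // Round both boundaries up to the next multiple of NB_COLS so every chunk
    // starts on a repacked block boundary; the end is then clamped to ne01.
    src0_start = (src0_start % nb_cols) ? src0_start + nb_cols - (src0_start % nb_cols) : src0_start;
    src0_end   = (src0_end   % nb_cols) ? src0_end   + nb_cols - (src0_end   % nb_cols) : src0_end;
    if (src0_end > ne01) {
        src0_end = ne01;
    }

    // An empty range (src0_start >= src0_end) means this and every later chunk
    // lies past the end of src0, which is why the loop in the diff breaks there.
    return {src0_start, src0_end};
}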