Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions ggml/src/ggml-cpu/repack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1678,10 +1678,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;

// Ensure minimum chunk size to avoid alignment issues with high thread counts
// Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
const int64_t min_chunk_size = NB_COLS;
if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
}

if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
}

// Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
// This prevents creating too many tiny chunks that could overlap after alignment
const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
if (nchunk > max_nchunk) {
nchunk = max_nchunk;
}

if (ith == 0) {
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
ggml_threadpool_chunk_set(params->threadpool, nth);
Expand All @@ -1695,8 +1709,15 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
while (current_chunk < nchunk) {
int64_t src0_start = (current_chunk * ne01) / nchunk;
int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;

// Align boundaries to NB_COLS - round up to ensure all data is included
// The chunk size limiting above ensures chunks are large enough to prevent overlaps
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
if (src0_end > ne01) {
src0_end = ne01;
}

if (src0_start >= src0_end) {
break;
}
Expand Down Expand Up @@ -1808,8 +1829,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
int64_t src0_cur_start = (ith * ne01) / nth;
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;

// Align boundaries to NB_COLS - round up to ensure all data is included
src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
if (src0_cur_end > ne01) {
src0_cur_end = ne01;
}

if (src0_cur_start >= src0_cur_end) {
return;
Expand Down
Loading