diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 72ee93a5abc..b62361ef9df 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1180,6 +1180,11 @@ template src[0]; + + switch (src0->type) { + case GGML_TYPE_Q4_K: + ggml_compute_forward_get_rows_q4_Kx8(params, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + } + + static void ggml_compute_forward_get_rows_q4_Kx8( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == ggml_type_size(src0->type)); + assert(ggml_nrows(dst) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1) / nth; + + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + constexpr int nrows_interleaved = 8; + const size_t sizeof_one_repacked_block = sizeof(block_q4_Kx8); + + const int num_repacked_blocks_per_row_width = nc / QK_K; + + const size_t stride_between_actual_row_groups = num_repacked_blocks_per_row_width * sizeof_one_repacked_block; + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i / (ne11 * ne10); + const int64_t i11 = (i - i12 * ne11 * ne10) / ne10; + const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10); + const int64_t i01 = *(int32_t *)((char *)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); // original logical row + + GGML_ASSERT(i01 >= 0 && i01 < ne01); + + const int row_group_idx = i01 / nrows_interleaved; + const int row_idx_in_group = i01 % nrows_interleaved; + + const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11 * nb02 + i12 * nb03; + + // Pointer to the first block_q4_Kx8 of the identified row_group_idx + const block_q4_Kx8 * p_first_repacked_block_of_group_x8 = (const block_q4_Kx8 *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups); + + dequantize_row_q4_Kx8( + p_first_repacked_block_of_group_x8, + (float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group); + } + } + + /** + * Dequantizes a single logical row from the repacked q4_Kx8 data format. + * + * @param p_repacked_blocks Pointer to the start of the 'block_q4_Kx8' structures for the entire row. + * @param y Output buffer for the dequantized float values. + * @param k Total number of elements (columns) in the logical row. + * @param row_idx_in_group The index (0-7) of the logical row to extract from the interleaved data. + */ + + static void dequantize_row_q4_Kx8( + const void * GGML_RESTRICT p_repacked_blocks, + float * GGML_RESTRICT y, + int64_t k, + int row_idx_in_group) { + + assert(k % QK_K == 0); + assert(row_idx_in_group >= 0 && row_idx_in_group < 8); + + const int nb = k / QK_K; + const block_q4_Kx8 * blocks = (const block_q4_Kx8 *)p_repacked_blocks; + + for (int i = 0; i < nb; i++) { + const block_q4_Kx8 * current_block = &blocks[i]; + + const float d_super_block = GGML_FP16_TO_FP32(current_block->d[row_idx_in_group]); + const float dmin_super_block = GGML_FP16_TO_FP32(current_block->dmin[row_idx_in_group]); + + const uint8_t * ptr_qs_base = current_block->qs; + const uint8_t * ptr_repacked_scales = (const uint8_t *)current_block->scales; + int is = 0, chunk_group_start_idx = 0; + for (int j = 0; j < QK_K; j += 64) { + + uint8_t sc1, m1_val, sc2, m2_val; + const uint8_t *scales_repacked_data; + + scales_repacked_data = &ptr_repacked_scales[(is + 0) * 12]; + get_scale_min_k4(row_idx_in_group, scales_repacked_data, &sc1, &m1_val); + + scales_repacked_data = &ptr_repacked_scales[(is + 1) * 12]; + get_scale_min_k4(row_idx_in_group, scales_repacked_data, &sc2, &m2_val); + + const float d1 = d_super_block * sc1; + const float m1 = dmin_super_block * m1_val; + const float d2 = d_super_block * sc2; + const float m2 = dmin_super_block * m2_val; + + for (int idx = 0; idx < 4; idx++) { + const uint8_t * ptr_qs_chunk = ptr_qs_base + ((chunk_group_start_idx + idx) * 64) + row_idx_in_group * 8; + for (int l = 0; l < 8; ++l) *y++ = d1 * (ptr_qs_chunk[l] & 0xF) - m1; // 16 elements of quants + } + + for (int idx = 0; idx < 4; idx++) { + const uint8_t * ptr_qs_chunk = ptr_qs_base + ((chunk_group_start_idx + idx) * 64) + row_idx_in_group * 8; + for (int l = 0; l < 8; ++l) *y++ = d2 * (ptr_qs_chunk[l] >> 4) - m2; // 16 elements of quants + } + is += 2; + chunk_group_start_idx += 4; + } + } + } + + static inline void get_scale_min_k4(int j, const uint8_t *GGML_RESTRICT s, uint8_t *GGML_RESTRICT d, uint8_t *GGML_RESTRICT m) { + if (j < 4) { + *d = s[j] & 63; + *m = s[j + 4] & 63; + } else { + *d = (s[j + 4] & 0xF) | ((s[j - 4] >> 6) << 4); + *m = (s[j + 4] >> 4) | ((s[j - 0] >> 6) << 4); + } + } + int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), (int) NB_COLS, (int) INTER_SIZE); @@ -1538,12 +1685,23 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { //if (op->src[1]->type == GGML_TYPE_Q8_0) { // return true; //} + } else if (op->op == GGML_OP_GET_ROWS + && op->src[0]->buffer + && (ggml_n_dims(op->src[0]) == 2) + && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type() + && ggml_repack_get_optimal_repack_type(op->src[0])) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[0]->type == GGML_TYPE_Q4_K) { + return true; + } } return false; } ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { - if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) { + if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_GET_ROWS) { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } diff --git a/src/whisper.cpp b/src/whisper.cpp index 347cc178ee7..e5922227e6e 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -1437,24 +1437,25 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * // GPU and default CPU backend support all operators op_supported = true; } else { - switch (op) { - // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT - case GGML_OP_MUL_MAT: { - ggml_init_params params = { - /*.mem_size =*/ 2 * ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; + ggml_init_params params = { + /*.mem_size =*/ 2 * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; - ggml_context_ptr ctx_ptr { ggml_init(params) }; - if (!ctx_ptr) { - throw std::runtime_error("failed to create ggml context"); - } - ggml_context * ctx = ctx_ptr.get(); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error("failed to create ggml context"); + } + ggml_context * ctx = ctx_ptr.get(); - ggml_tensor * op_tensor = nullptr; + ggml_tensor * op_tensor = nullptr; + + int64_t n_ctx = hparams.n_audio_ctx; - int64_t n_ctx = hparams.n_audio_ctx; + switch (op) { + // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT & GGML_OP_GET_ROWS (repacked - q4_K) + case GGML_OP_MUL_MAT: { ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); op_tensor = ggml_mul_mat(ctx, w, b); @@ -1466,6 +1467,18 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * w->buffer = nullptr; break; } + case GGML_OP_GET_ROWS: { + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx); + op_tensor = ggml_get_rows(ctx, w, b); + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + break; + } default: { op_supported = false; break;