Skip to content

Commit ed1d3a2

Browse files
committed
Get_Rows & Dequantize implementation adapted to work for repacked weights of type q4_0
1 parent 705db0f commit ed1d3a2

File tree

2 files changed

+172
-16
lines changed

2 files changed

+172
-16
lines changed

ggml/src/ggml-cpu/repack.cpp

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
11811181
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
11821182
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
11831183
return true;
1184+
case GGML_OP_GET_ROWS:
1185+
size = 0; // GET_ROWS (standard and repacked) doesn't need a work buffer
1186+
return true;
11841187
default:
11851188
// GGML_ABORT("fatal error");
11861189
break;
@@ -1196,6 +1199,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
11961199
case GGML_OP_MUL_MAT_ID:
11971200
forward_mul_mat_id(params, op);
11981201
return true;
1202+
case GGML_OP_GET_ROWS:
1203+
forward_get_rows(params, op);
1204+
return true;
11991205
default:
12001206
// GGML_ABORT("fatal error");
12011207
break;
@@ -1401,6 +1407,132 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
14011407
#undef MMID_MATRIX_ROW
14021408
}
14031409

1410+
void forward_get_rows(const ggml_compute_params * params,
1411+
ggml_tensor * dst) {
1412+
const ggml_tensor * src0 = dst->src[0];
1413+
1414+
switch (src0->type) {
1415+
case GGML_TYPE_Q4_0: {
1416+
ggml_compute_forward_get_rows_q4_0x8(params, dst);
1417+
} break;
1418+
default:
1419+
GGML_ABORT("fatal error");
1420+
break;
1421+
}
1422+
}
1423+
1424+
// Gather rows from a q4_0 weight tensor stored in the x8 row-interleaved
// ("repacked") layout and dequantize each selected row to f32 in dst.
//   dst->src[0] : repacked q4_0 weights — rows are grouped 8 at a time, and
//                 within a group the quant blocks of the 8 rows are interleaved
//                 into block_q4_0x8 units (one x8 block per QK4_0 columns).
//   dst->src[1] : i32 indices of the logical rows to gather.
// Work is split across threads by output row, mirroring the standard
// ggml_compute_forward_get_rows_* kernels.
static void ggml_compute_forward_get_rows_q4_0x8(
    const ggml_compute_params * params,
    ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_TENSOR_BINARY_OP_LOCALS

    const int64_t nc = ne00;                  // columns per logical row
    const int64_t nr = ggml_nelements(src1);  // total number of rows to gather

    assert(ne0 == nc);
    assert(ne02 == ne11);
    assert(nb00 == ggml_type_size(src0->type));
    assert(ggml_nrows(dst) == nr);

    const int ith = params->ith;
    const int nth = params->nth;

    // rows per thread
    const int dr = (nr + nth - 1) / nth;

    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    constexpr int nrows_interleaved = 8;  // rows per interleave group (x8 layout)
    const size_t sizeof_one_repacked_block = sizeof(block_q4_0x8);

    // Number of x8 blocks spanning one row group's width (nc columns).
    const int num_repacked_blocks_per_row_width = nc / QK4_0;

    // Byte stride from one 8-row group to the next in the repacked buffer.
    const size_t stride_between_actual_row_groups = num_repacked_blocks_per_row_width * sizeof_one_repacked_block;

    for (int64_t i = ir0; i < ir1; ++i) {
        // Decompose the flat output-row index i into src1 coordinates (i10,i11,i12).
        const int64_t i12 = i / (ne11 * ne10);
        const int64_t i11 = (i - i12 * ne11 * ne10) / ne10;
        const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10);
        const int64_t i01 = *(int32_t *)((char *)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); // original logical row

        GGML_ASSERT(i01 >= 0 && i01 < ne01);

        // Locate the 8-row group containing row i01, and its slot within the group.
        int row_group_idx = i01 / nrows_interleaved;
        const int row_idx_in_group = i01 % nrows_interleaved;

        // NOTE(review): higher-dim offset uses (i11, i12) against (nb02, nb03),
        // matching the reference get_rows kernels — but supports_op only admits
        // 2-D src0, so these offsets are expected to be zero here; confirm.
        const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11 * nb02 + i12 * nb03;

        // Pointer to the first block_q4_0x8 of the identified row_group_idx
        const block_q4_0x8 * p_first_repacked_block_of_group_x8 = (const block_q4_0x8 *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);

        // De-interleave and dequantize exactly one logical row into dst.
        dequantize_row_q4_0x8(
            p_first_repacked_block_of_group_x8,
            (float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group);
    }
}
1478+
1479+
/**
 * Dequantizes a single logical row from data repacked with quant interleaving.
 *
 * The x8 repack stores, per block_q4_0x8: the 8 rows' fp16 scales in d[0..7],
 * then the rows' quant bytes interleaved in 8-byte chunks — first each row's
 * first 8 qs bytes, then each row's second 8 qs bytes — with every chunk
 * XOR-ed by 0x88..88 during repacking (undone here before nibble extraction).
 *
 * @param p_repacked_group_column_blocks Pointer to the start of 'block_q4_0x8' for the row group.
 * @param y Output buffer for the dequantized float values.
 * @param k Total number of elements (columns) in the logical row.
 * @param row_idx_in_group Index (0-7) of the logical row to dequantize.
 */
static void dequantize_row_q4_0x8(
    const block_q4_0x8 * GGML_RESTRICT p_repacked_group_column_blocks,
    float * GGML_RESTRICT y,
    int64_t k,
    int row_idx_in_group) {
    const int GGML_Q4_0_X8_INTERLEAVE_SIZE = 8;
    assert(k % QK4_0 == 0);
    assert(row_idx_in_group >= 0 && row_idx_in_group < GGML_Q4_0_X8_INTERLEAVE_SIZE);

    const int nb = k / QK4_0;                          // x8 blocks spanning the row
    const int bytes_for_half_elements = (QK4_0 / 2) / 2;  // 8 bytes = half a row's qs

    // Offset of the second 8-byte chunk region within an x8 block's qs[].
    const int offset_to_second_half_data = bytes_for_half_elements * GGML_Q4_0_X8_INTERLEAVE_SIZE;
    const uint64_t xor_mask = 0x8888888888888888ULL;   // mask applied per chunk at repack time
    const int qk4_0_half_elements = QK4_0 / 2;

    for (int i = 0; i < nb; ++i) {
        const block_q4_0x8 * current_column_repacked_block = &p_repacked_group_column_blocks[i];
        const float d_val = GGML_FP16_TO_FP32(current_column_repacked_block->d[row_idx_in_group]);
        float * y_curr = y + i * QK4_0;

        // First 8 original qs bytes of this row: chunk #row_idx_in_group.
        const int8_t * qs_first_half_repacked_ptr = &(current_column_repacked_block->qs[row_idx_in_group * bytes_for_half_elements]);

        // memcpy avoids unaligned/aliasing UB when reading the 8-byte chunk.
        uint64_t first_half_chunk_u64;
        memcpy(&first_half_chunk_u64, qs_first_half_repacked_ptr, sizeof(uint64_t));
        first_half_chunk_u64 ^= xor_mask; // Reverse the XOR
        const uint8_t * original_qs_first_half_bytes = (const uint8_t *)&first_half_chunk_u64;

        // Second 8 original qs bytes of this row, from the second chunk region.
        const int8_t * qs_second_half_repacked_ptr = &(current_column_repacked_block->qs[offset_to_second_half_data + (row_idx_in_group * bytes_for_half_elements)]);

        uint64_t second_half_chunk_u64;
        memcpy(&second_half_chunk_u64, qs_second_half_repacked_ptr, sizeof(uint64_t));
        second_half_chunk_u64 ^= xor_mask; // Reverse the XOR
        const uint8_t * original_qs_second_half_bytes = (const uint8_t *)&second_half_chunk_u64;

        // dequantizing all QK4_0's for this block.
        // q4_0 layout: byte j holds element j (low nibble) and element
        // j + QK4_0/2 (high nibble), each stored as (value + 8) in 4 bits.
        for (int j = 0; j < bytes_for_half_elements; ++j) {
            const uint8_t quant_byte_first = original_qs_first_half_bytes[j];
            y_curr[j] = ((quant_byte_first & 0x0F) - 8) * d_val;
            y_curr[j + qk4_0_half_elements] = ((quant_byte_first >> 4) - 8) * d_val;

            const uint8_t quant_byte_second = original_qs_second_half_bytes[j];
            const int out_idx_base_second_half = j + bytes_for_half_elements; // Offset for the second set of low nibbles
            y_curr[out_idx_base_second_half] = ((quant_byte_second & 0x0F) - 8) * d_val;
            y_curr[out_idx_base_second_half + qk4_0_half_elements] = ((quant_byte_second >> 4) - 8) * d_val;
        }
    }
}
1535+
14041536
int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
14051537
GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
14061538
(int) NB_COLS, (int) INTER_SIZE);
@@ -1533,12 +1665,23 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
15331665
//if (op->src[1]->type == GGML_TYPE_Q8_0) {
15341666
// return true;
15351667
//}
1668+
} else if (op->op == GGML_OP_GET_ROWS
1669+
&& op->src[0]->buffer
1670+
&& (ggml_n_dims(op->src[0]) == 2)
1671+
&& op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()
1672+
&& ggml_repack_get_optimal_repack_type(op->src[0])) {
1673+
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
1674+
return false;
1675+
}
1676+
if (op->src[0]->type == GGML_TYPE_Q4_0) {
1677+
return true;
1678+
}
15361679
}
15371680
return false;
15381681
}
15391682

15401683
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
1541-
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
1684+
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_GET_ROWS) {
15421685
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) {
15431686
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
15441687
}

src/whisper.cpp

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,24 +1437,25 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
14371437
// GPU and default CPU backend support all operators
14381438
op_supported = true;
14391439
} else {
1440-
switch (op) {
1441-
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
1442-
case GGML_OP_MUL_MAT: {
1443-
ggml_init_params params = {
1444-
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
1445-
/*.mem_buffer =*/ nullptr,
1446-
/*.no_alloc =*/ true,
1447-
};
1440+
ggml_init_params params = {
1441+
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
1442+
/*.mem_buffer =*/ nullptr,
1443+
/*.no_alloc =*/ true,
1444+
};
14481445

1449-
ggml_context_ptr ctx_ptr { ggml_init(params) };
1450-
if (!ctx_ptr) {
1451-
throw std::runtime_error("failed to create ggml context");
1452-
}
1453-
ggml_context * ctx = ctx_ptr.get();
1446+
ggml_context_ptr ctx_ptr { ggml_init(params) };
1447+
if (!ctx_ptr) {
1448+
throw std::runtime_error("failed to create ggml context");
1449+
}
1450+
ggml_context * ctx = ctx_ptr.get();
14541451

1455-
ggml_tensor * op_tensor = nullptr;
1452+
ggml_tensor * op_tensor = nullptr;
1453+
1454+
int64_t n_ctx = hparams.n_audio_ctx;
14561455

1457-
int64_t n_ctx = hparams.n_audio_ctx;
1456+
switch (op) {
1457+
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT & GGML_OP_GET_ROWS (q4_0)
1458+
case GGML_OP_MUL_MAT: {
14581459
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
14591460
op_tensor = ggml_mul_mat(ctx, w, b);
14601461

@@ -1466,6 +1467,18 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
14661467
w->buffer = nullptr;
14671468
break;
14681469
}
1470+
case GGML_OP_GET_ROWS: {
1471+
ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
1472+
op_tensor = ggml_get_rows(ctx, w, b);
1473+
1474+
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
1475+
GGML_ASSERT(w->buffer == nullptr);
1476+
w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
1477+
op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
1478+
ggml_backend_buffer_free(w->buffer);
1479+
w->buffer = nullptr;
1480+
break;
1481+
}
14691482
default: {
14701483
op_supported = false;
14711484
break;

0 commit comments

Comments
 (0)