@@ -1181,6 +1181,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
                 size = GGML_PAD(size, sizeof(int64_t));  // + padding for next bloc.
                 size += sizeof(int64_t) * (1 + op->src[0]->ne[2]) * op->src[1]->ne[2];
                 return true;
+            case GGML_OP_GET_ROWS:
+                size = 0;  // GET_ROWS (standard and repacked) doesn't need a work buffer
+                return true;
             default:
                 // GGML_ABORT("fatal error");
                 break;
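For context, the MUL_MAT_ID branch above sizes a scratch area for per-expert row bookkeeping: it aligns the running size to `sizeof(int64_t)` and then reserves `(1 + op->src[0]->ne[2]) * op->src[1]->ne[2]` int64 slots. A minimal standalone sketch of that arithmetic, with made-up shapes standing in for the tensor dimensions (`GGML_PAD` is reproduced from ggml.h so the snippet compiles on its own):

```cpp
#include <cstdint>
#include <cstdio>

// Round x up to the next multiple of n (n must be a power of two), as in ggml.h.
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    size_t size = 1001;               // bytes already reserved for other scratch data
    const int64_t n_experts = 8;      // stands in for op->src[0]->ne[2] (illustrative)
    const int64_t n_tokens  = 4;      // stands in for op->src[1]->ne[2] (illustrative)

    size  = GGML_PAD(size, sizeof(int64_t));               // 1001 -> 1008
    size += sizeof(int64_t) * (1 + n_experts) * n_tokens;  // + 8 * 9 * 4 = 288

    printf("work size: %zu bytes\n", size);                // 1296
    return 0;
}
```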
@@ -1196,6 +1199,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
             case GGML_OP_MUL_MAT_ID:
                 forward_mul_mat_id(params, op);
                 return true;
+            case GGML_OP_GET_ROWS:
+                forward_get_rows(params, op);
+                return true;
             default:
                 // GGML_ABORT("fatal error");
                 break;
@@ -1401,6 +1407,132 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
 #undef MMID_MATRIX_ROW
     }
 
+    void forward_get_rows(const ggml_compute_params * params,
+                          ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+
+        switch (src0->type) {
+            case GGML_TYPE_Q4_0: {
+                ggml_compute_forward_get_rows_q4_0x8(params, dst);
+            } break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+    }
+
+    static void ggml_compute_forward_get_rows_q4_0x8(
+            const ggml_compute_params * params,
+            ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int64_t nc = ne00;
+        const int64_t nr = ggml_nelements(src1);
+
+        assert(ne0 == nc);
+        assert(ne02 == ne11);
+        assert(nb00 == ggml_type_size(src0->type));
+        assert(ggml_nrows(dst) == nr);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        // rows per thread
+        const int dr = (nr + nth - 1)/nth;
+
+        // row range for this thread
+        const int ir0 = dr*ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        constexpr int nrows_interleaved = 8;
+        const size_t sizeof_one_repacked_block = sizeof(block_q4_0x8);
+
+        const int num_repacked_blocks_per_row_width = nc/QK4_0;
+
+        const size_t stride_between_actual_row_groups = num_repacked_blocks_per_row_width * sizeof_one_repacked_block;
+
+        for (int64_t i = ir0; i < ir1; ++i) {
+            const int64_t i12 = i/(ne11*ne10);
+            const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+            const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+            const int64_t i01 = *(int32_t *)((char *)src1->data + i10*nb10 + i11*nb11 + i12*nb12);  // original logical row
+
+            GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+            const int row_group_idx    = i01/nrows_interleaved;
+            const int row_idx_in_group = i01 % nrows_interleaved;
+
+            const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11*nb02 + i12*nb03;
+
+            // pointer to the first block_q4_0x8 of the identified row group
+            const block_q4_0x8 * p_first_repacked_block_of_group_x8 = (const block_q4_0x8 *)(base_ptr_for_higher_dims_in_src0 + row_group_idx*stride_between_actual_row_groups);
+
+            dequantize_row_q4_0x8(
+                p_first_repacked_block_of_group_x8,
+                (float *)((char *)dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc, row_idx_in_group);
+        }
+    }
+
+    /**
+     * Dequantizes a single logical row from data repacked with quant interleaving.
+     *
+     * @param p_repacked_group_column_blocks Pointer to the start of 'block_q4_0x8' for the row group.
+     * @param y                              Output buffer for the dequantized float values.
+     * @param k                              Total number of elements (columns) in the logical row.
+     * @param row_idx_in_group               Index (0-7) of the logical row to dequantize.
+     */
+    static void dequantize_row_q4_0x8(
+            const block_q4_0x8 * GGML_RESTRICT p_repacked_group_column_blocks,
+            float * GGML_RESTRICT y,
+            int64_t k,
+            int row_idx_in_group) {
+        const int GGML_Q4_0_X8_INTERLEAVE_SIZE = 8;
+        assert(k % QK4_0 == 0);
+        assert(row_idx_in_group >= 0 && row_idx_in_group < GGML_Q4_0_X8_INTERLEAVE_SIZE);
+
+        const int nb = k/QK4_0;
+        const int bytes_for_half_elements = (QK4_0/2)/2;
+
+        const int offset_to_second_half_data = bytes_for_half_elements * GGML_Q4_0_X8_INTERLEAVE_SIZE;
+        const uint64_t xor_mask = 0x8888888888888888ULL;
+        const int qk4_0_half_elements = QK4_0/2;
+
+        for (int i = 0; i < nb; ++i) {
+            const block_q4_0x8 * current_column_repacked_block = &p_repacked_group_column_blocks[i];
+            const float d_val = GGML_FP16_TO_FP32(current_column_repacked_block->d[row_idx_in_group]);
+            float * y_curr = y + i*QK4_0;
+
+            const uint8_t * qs_first_half_repacked_ptr = &(current_column_repacked_block->qs[row_idx_in_group * bytes_for_half_elements]);
+
+            uint64_t first_half_chunk_u64;
+            memcpy(&first_half_chunk_u64, qs_first_half_repacked_ptr, sizeof(uint64_t));
+            first_half_chunk_u64 ^= xor_mask;  // reverse the XOR applied during repacking
+            const uint8_t * original_qs_first_half_bytes = (const uint8_t *)&first_half_chunk_u64;
+
+            const uint8_t * qs_second_half_repacked_ptr = &(current_column_repacked_block->qs[offset_to_second_half_data + row_idx_in_group * bytes_for_half_elements]);
+
+            uint64_t second_half_chunk_u64;
+            memcpy(&second_half_chunk_u64, qs_second_half_repacked_ptr, sizeof(uint64_t));
+            second_half_chunk_u64 ^= xor_mask;  // reverse the XOR applied during repacking
+            const uint8_t * original_qs_second_half_bytes = (const uint8_t *)&second_half_chunk_u64;
+
+            // dequantize all QK4_0 elements of this block
+            for (int j = 0; j < bytes_for_half_elements; ++j) {
+                const uint8_t quant_byte_first = original_qs_first_half_bytes[j];
+                y_curr[j]                       = ((quant_byte_first & 0x0F) - 8) * d_val;
+                y_curr[j + qk4_0_half_elements] = ((quant_byte_first >> 4) - 8) * d_val;
+
+                const uint8_t quant_byte_second = original_qs_second_half_bytes[j];
+                const int out_idx_base_second_half = j + bytes_for_half_elements;  // offset for the second set of low nibbles
+                y_curr[out_idx_base_second_half]                       = ((quant_byte_second & 0x0F) - 8) * d_val;
+                y_curr[out_idx_base_second_half + qk4_0_half_elements] = ((quant_byte_second >> 4) - 8) * d_val;
+            }
+        }
+    }
+
     int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
         GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
                        (int) NB_COLS, (int) INTER_SIZE);
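Two details of the new path are easy to check in isolation: how a logical row index maps into the 8-row interleaved layout, and how one packed byte turns back into two floats. The sketch below replays both with illustrative numbers only; `sizeof(block_q4_0x8)` is derived here from ggml's block definition (8 fp16 scales plus 8 rows of QK4_0/2 nibble bytes) rather than taken from a real build, and the scale and packed bytes are made up:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    // --- (1) row addressing, as in forward_get_rows ---
    const int     QK4_0             = 32;
    const int     nrows_interleaved = 8;
    const size_t  sizeof_x8_block   = 8 * 2 + 8 * (QK4_0 / 2);  // assumed 144 bytes

    const int64_t nc  = 4096;  // columns in the logical row (illustrative)
    const int64_t i01 = 19;    // logical row requested by src1 (illustrative)

    const int    row_group_idx    = (int)(i01 / nrows_interleaved);  // -> 2
    const int    row_idx_in_group = (int)(i01 % nrows_interleaved);  // -> 3
    const size_t group_stride     = (nc / QK4_0) * sizeof_x8_block;

    printf("row %lld -> group %d, slot %d, byte offset %zu\n",
           (long long) i01, row_group_idx, row_idx_in_group,
           row_group_idx * group_stride);

    // --- (2) nibble decode, as in the inner loop of dequantize_row_q4_0x8 ---
    const float d = 0.25f;  // per-block scale, made up for the demo
    const uint8_t packed[8] = { 0x88, 0x97, 0xA6, 0xB5, 0xC4, 0xD3, 0xE2, 0xF1 };

    uint64_t chunk;
    memcpy(&chunk, packed, sizeof(chunk));
    chunk ^= 0x8888888888888888ULL;  // reverse the XOR applied at repack time

    const uint8_t * q = (const uint8_t *) &chunk;
    for (int j = 0; j < 8; ++j) {
        const float lo = ((q[j] & 0x0F) - 8) * d;  // low nibble  -> element j
        const float hi = ((q[j] >> 4)   - 8) * d;  // high nibble -> element j + QK4_0/2
        printf("byte %d: %5.2f %5.2f\n", j, lo, hi);
    }
    return 0;
}
```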
@@ -1533,12 +1665,23 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
             // if (op->src[1]->type == GGML_TYPE_Q8_0) {
             //     return true;
             // }
+        } else if (op->op == GGML_OP_GET_ROWS
+                   && op->src[0]->buffer
+                   && (ggml_n_dims(op->src[0]) == 2)
+                   && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()
+                   && ggml_repack_get_optimal_repack_type(op->src[0])) {
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            if (op->src[0]->type == GGML_TYPE_Q4_0) {
+                return true;
+            }
         }
         return false;
     }
 
     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
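Read as a standalone predicate, the new `supports_op` gate takes the repacked GET_ROWS path only for 2-D Q4_0 weights stored in the repack buffer, with the row ids resident in host memory. A restatement for readability, as shown below; `repack_get_rows_supported` is a hypothetical helper, not a function in the patch, and it assumes the ggml-cpu internal headers:

```cpp
// Hypothetical restatement of the gate above; the real check lives inline
// in extra_buffer_type::supports_op.
static bool repack_get_rows_supported(const struct ggml_tensor * op) {
    const struct ggml_tensor * src0 = op->src[0];  // repacked weight matrix
    const struct ggml_tensor * src1 = op->src[1];  // row indices
    if (op->op != GGML_OP_GET_ROWS || !src0->buffer) {
        return false;
    }
    return ggml_n_dims(src0) == 2                                       // plain 2-D weights only
        && src0->buffer->buft == ggml_backend_cpu_repack_buffer_type()  // stored repacked
        && ggml_repack_get_optimal_repack_type(src0)                    // a repack layout applies
        && !(src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft))  // ids on host
        && src0->type == GGML_TYPE_Q4_0;                                // only Q4_0 implemented so far
}
```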