@@ -6046,6 +6046,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
                 size = GGML_PAD(size, sizeof(int64_t));  // + padding for next bloc.
                 size += sizeof(int64_t) * (1 + op->src[0]->ne[2]) * op->src[1]->ne[2];
                 return true;
+            case GGML_OP_GET_ROWS:
+                size = 0;  // GET_ROWS (standard and repacked) doesn't need a work buffer
+                return true;
             default:
                 // GGML_ABORT("fatal error");
                 break;
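For readers skimming the hunk above: `MUL_MAT_ID` reserves a scratch area (the quantized copy of `src[1]`, padded to `int64_t` alignment, plus `(1 + src0->ne[2]) * src1->ne[2]` bookkeeping entries), while the new `GET_ROWS` case reports zero because each row is dequantized straight into `dst`. Below is a minimal standalone sketch of that sizing arithmetic; `pad_to` and `mul_mat_id_work_size` are illustrative stand-ins, not ggml API:

```cpp
// Standalone sketch of the work-size logic above. pad_to() mirrors what
// GGML_PAD does; the function and parameter names are hypothetical.
#include <cstddef>
#include <cstdint>
#include <cstdio>

static size_t pad_to(size_t size, size_t align) {
    return (size + align - 1) / align * align;
}

// base      : bytes already needed for the quantized copy of src1
// n_experts : src0->ne[2]
// n_ids     : src1->ne[2]
static size_t mul_mat_id_work_size(size_t base, int64_t n_experts, int64_t n_ids) {
    size_t size = pad_to(base, sizeof(int64_t));        // + padding for next bloc
    size += sizeof(int64_t) * (1 + n_experts) * n_ids;  // per-expert row bookkeeping
    return size;
}

int main() {
    // GET_ROWS needs no scratch at all: rows are dequantized directly into dst.
    const size_t get_rows_size = 0;
    printf("mul_mat_id: %zu bytes, get_rows: %zu bytes\n",
           mul_mat_id_work_size(4096, 8, 4), get_rows_size);
    return 0;
}
```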
@@ -6061,13 +6064,142 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
             case GGML_OP_MUL_MAT_ID:
                 forward_mul_mat_id(params, op);
                 return true;
+            case GGML_OP_GET_ROWS:
+                forward_get_rows(params, op);
+                return true;
             default:
                 // GGML_ABORT("fatal error");
                 break;
         }
         return false;
     }
 
+    void forward_get_rows(const ggml_compute_params * params,
+                          ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+
+        switch (src0->type) {
+            case GGML_TYPE_Q4_0: {
+                ggml_compute_forward_get_rows_q4_0x8(params, dst);
+            } break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+    }
+
+    static void ggml_compute_forward_get_rows_q4_0x8(
+            const ggml_compute_params * params,
+            ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int64_t nc = ne00;
+        const int64_t nr = ggml_nelements(src1);
+
+        assert(ne0 == nc);
+        assert(ne02 == ne11);
+        assert(nb00 == ggml_type_size(src0->type));
+        assert(ggml_nrows(dst) == nr);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        // rows per thread
+        const int dr = (nr + nth - 1) / nth;
+
+        // row range for this thread
+        const int ir0 = dr * ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        constexpr int nrows_interleaved = 8;
+        const size_t sizeof_one_repacked_block = sizeof(block_q4_0x8);
+
+        const int num_repacked_blocks_per_row_width = nc / QK4_0;
+
+        const size_t stride_between_actual_row_groups = num_repacked_blocks_per_row_width * sizeof_one_repacked_block;
+
+        for (int64_t i = ir0; i < ir1; ++i) {
+            const int64_t i12 = i / (ne11 * ne10);
+            const int64_t i11 = (i - i12 * ne11 * ne10) / ne10;
+            const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10);
+            const int64_t i01 = *(int32_t *)((char *)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12);  // original logical row
+
+            GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+            int row_group_idx = i01 / nrows_interleaved;
+            const int row_idx_in_group = i01 % nrows_interleaved;
+
+            const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11 * nb02 + i12 * nb03;
+
+            // Pointer to the first block_q4_0x8 of the identified row_group_idx
+            const block_q4_0x8 * p_first_repacked_block_of_group_x8 = (const block_q4_0x8 *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);
+
+            dequantize_row_q4_0x8(
+                p_first_repacked_block_of_group_x8,
+                (float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group);
+        }
+    }
+
+    /**
+     * Dequantizes a single logical row from data repacked with quant interleaving.
+     *
+     * @param p_repacked_group_column_blocks Pointer to the start of the 'block_q4_0x8' blocks for the row group.
+     * @param y                              Output buffer for the dequantized float values.
+     * @param k                              Total number of elements (columns) in the logical row.
+     * @param row_idx_in_group               Index (0-7) of the logical row to dequantize within the group.
+     */
+    static void dequantize_row_q4_0x8(
+            const block_q4_0x8 * GGML_RESTRICT p_repacked_group_column_blocks,
+            float * GGML_RESTRICT y,
+            int64_t k,
+            int row_idx_in_group) {
+        const int GGML_Q4_0_X8_INTERLEAVE_SIZE = 8;
+        assert(k % QK4_0 == 0);
+        assert(row_idx_in_group >= 0 && row_idx_in_group < GGML_Q4_0_X8_INTERLEAVE_SIZE);
+
+        const int nb = k / QK4_0;
+        const int bytes_for_half_elements = (QK4_0 / 2) / 2;
+
+        const int offset_to_second_half_data = bytes_for_half_elements * GGML_Q4_0_X8_INTERLEAVE_SIZE;
+        const uint64_t xor_mask = 0x8888888888888888ULL;
+        const int qk4_0_half_elements = QK4_0 / 2;
+
+        for (int i = 0; i < nb; ++i) {
+            const block_q4_0x8 * current_column_repacked_block = &p_repacked_group_column_blocks[i];
+            const float d_val = GGML_FP16_TO_FP32(current_column_repacked_block->d[row_idx_in_group]);
+            float * y_curr = y + i * QK4_0;
+
+            const int8_t * qs_first_half_repacked_ptr = &(current_column_repacked_block->qs[row_idx_in_group * bytes_for_half_elements]);
+
+            uint64_t first_half_chunk_u64;
+            memcpy(&first_half_chunk_u64, qs_first_half_repacked_ptr, sizeof(uint64_t));
+            first_half_chunk_u64 ^= xor_mask;  // reverse the XOR applied during repacking
+            const uint8_t * original_qs_first_half_bytes = (const uint8_t *)&first_half_chunk_u64;
+
+            const int8_t * qs_second_half_repacked_ptr = &(current_column_repacked_block->qs[offset_to_second_half_data + row_idx_in_group * bytes_for_half_elements]);
+
+            uint64_t second_half_chunk_u64;
+            memcpy(&second_half_chunk_u64, qs_second_half_repacked_ptr, sizeof(uint64_t));
+            second_half_chunk_u64 ^= xor_mask;  // reverse the XOR applied during repacking
+            const uint8_t * original_qs_second_half_bytes = (const uint8_t *)&second_half_chunk_u64;
+
+            // dequantize all QK4_0 elements of this block
+            for (int j = 0; j < bytes_for_half_elements; ++j) {
+                const uint8_t quant_byte_first = original_qs_first_half_bytes[j];
+                y_curr[j]                       = ((quant_byte_first & 0x0F) - 8) * d_val;
+                y_curr[j + qk4_0_half_elements] = ((quant_byte_first >> 4) - 8) * d_val;
+
+                const uint8_t quant_byte_second = original_qs_second_half_bytes[j];
+                const int out_idx_base_second_half = j + bytes_for_half_elements;  // offset for the second set of low nibbles
+                y_curr[out_idx_base_second_half]                       = ((quant_byte_second & 0x0F) - 8) * d_val;
+                y_curr[out_idx_base_second_half + qk4_0_half_elements] = ((quant_byte_second >> 4) - 8) * d_val;
+            }
+        }
+    }
+
     void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
         const ggml_tensor * src0 = op->src[0];
         const ggml_tensor * src1 = op->src[1];
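To make the addressing in `ggml_compute_forward_get_rows_q4_0x8` concrete: rows are packed eight at a time, each 8-row group stores `ne00 / QK4_0` repacked blocks back to back, and within a block the quants of a given row sit at `row_idx_in_group * bytes_for_half_elements`, with every byte XORed with `0x88` at repack time. The sketch below is a self-contained illustration of that math; `block_q4_0x8_like` only mirrors the shape of ggml's `block_q4_0x8` and is not the real definition:

```cpp
// Illustrative sketch of the repacked-layout addressing used above.
#include <cstdint>
#include <cstdio>

constexpr int QK4_0 = 32;             // elements per Q4_0 block
constexpr int NROWS_INTERLEAVED = 8;  // rows packed per group

struct block_q4_0x8_like {
    uint16_t d[NROWS_INTERLEAVED];               // one fp16 scale per packed row
    int8_t   qs[QK4_0 / 2 * NROWS_INTERLEAVED];  // interleaved 4-bit quants
};

int main() {
    const int64_t ne00 = 4096;  // row width in elements
    const int64_t i01  = 42;    // logical row requested by src1

    // Each 8-row group stores ne00/QK4_0 repacked blocks contiguously.
    const int64_t blocks_per_row_width = ne00 / QK4_0;
    const size_t  group_stride = blocks_per_row_width * sizeof(block_q4_0x8_like);

    const int64_t row_group_idx    = i01 / NROWS_INTERLEAVED;  // which 8-row group
    const int     row_idx_in_group = i01 % NROWS_INTERLEAVED;  // slot inside the group
    const size_t  byte_offset      = row_group_idx * group_stride;

    printf("row %lld -> group %lld (byte offset %zu), slot %d\n",
           (long long)i01, (long long)row_group_idx, byte_offset, row_idx_in_group);

    // Decoding one quant byte after undoing the repack-time XOR 0x88:
    // low nibble -> element j, high nibble -> element j + QK4_0/2.
    const uint8_t repacked = 0x3A ^ 0x88;  // pretend this byte came from qs[]
    const uint8_t b = repacked ^ 0x88;     // reverse the XOR -> 0x3A again
    const float d_val = 0.05f;             // scale, already converted from fp16
    const float lo = ((b & 0x0F) - 8) * d_val;  // (0xA - 8) * 0.05 =  0.10
    const float hi = ((b >> 4)   - 8) * d_val;  // (0x3 - 8) * 0.05 = -0.25
    printf("lo=%.2f hi=%.2f\n", lo, hi);
    return 0;
}
```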
@@ -6398,12 +6530,23 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
             //     if (op->src[1]->type == GGML_TYPE_Q8_0) {
             //         return true;
             //     }
+        } else if (op->op == GGML_OP_GET_ROWS
+                && op->src[0]->buffer
+                && (ggml_n_dims(op->src[0]) == 2)
+                && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()
+                && ggml_aarch64_get_optimal_repack_type(op->src[0])) {
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            if (op->src[0]->type == GGML_TYPE_Q4_0) {
+                return true;
+            }
         }
         return false;
     }
 
     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
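The two overrides above complete the wiring: `supports_op()` only accepts 2-D `GGML_TYPE_Q4_0` weights living in the aarch64 repack buffer (and bails out if the row-index tensor is not in host memory), and `get_tensor_traits()` now also hands back the traits object stashed in `src[0]->extra` for `GGML_OP_GET_ROWS` nodes. A rough sketch of the dispatch contract this implies, using simplified stand-in types rather than ggml's internals:

```cpp
// Rough sketch of the dispatch flow implied by the overrides above. The names
// here (ComputeParams, TensorTraits, Tensor, try_repacked_path) are simplified
// stand-ins shown only to illustrate the contract, not ggml's real types.
struct Tensor;

struct ComputeParams { int ith, nth; };

struct TensorTraits {
    // corresponds to tensor_traits::compute_forward(), which routes
    // GGML_OP_GET_ROWS nodes to forward_get_rows()
    virtual bool compute_forward(ComputeParams * params, Tensor * node) = 0;
    virtual ~TensorTraits() = default;
};

struct Tensor {
    int      op;     // e.g. a GGML_OP_GET_ROWS node
    Tensor * src0;   // repacked weight tensor
    void *   extra;  // TensorTraits *, set when the buffer was repacked
};

// Returns true if the repacked path handled the node; otherwise the caller
// falls back to the generic CPU GET_ROWS kernel.
static bool try_repacked_path(ComputeParams * params, Tensor * node) {
    if (node->src0 && node->src0->extra) {
        auto * traits = static_cast<TensorTraits *>(node->src0->extra);
        return traits->compute_forward(params, node);
    }
    return false;
}
```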