diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 5bc773a95..e0eea7704 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -39,6 +39,8 @@ #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/add-id.cuh" #include "ggml-cuda/graph.cuh" +#include "ggml-cuda/mmq_id.cuh" +#include "ggml-cuda/quantize_id.cuh" #include #include @@ -2316,7 +2318,7 @@ static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); + //CUDA_CHECK(cudaStreamSynchronize(stream)); return is_ser; } @@ -2392,6 +2394,11 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } + if (ggml_is_quantized(src0->type) && ggml_cuda_can_use_mmq_id(src0->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { + ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr); + return false; + } + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers"); @@ -2662,16 +2669,85 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor } } - GGML_TENSOR_BINARY_OP_LOCALS - GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers"); GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_2->buffer) && "mul_mat_id does not support split buffers"); + GGML_TENSOR_BINARY_OP_LOCALS + cudaStream_t stream = ctx.stream(); const int64_t n_as = ne02; const int64_t n_ids = ids->ne[0]; + ggml_tensor dst_row = *dst; + + if (src1->ne[2] <= 2048 && // TODO: this depends on number of total vs number of active experts -> need to find optimum threshod + ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1 && + ggml_cuda_can_use_mmq_id(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { + + const int64_t ne_get_rows = ne12 * n_ids; + ggml_cuda_pool_alloc ids_device(ctx.pool(), ne_get_rows + ne_get_rows + n_as + 1); + auto ids_src1 = ids_device.get(); + auto ids_dst = ids_src1 + ne_get_rows; + auto expert_bounds = ids_dst + ne_get_rows; + + compute_row_ids((const int32_t *)ids->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_ids, ne11, nb11, nb12, ids->nb[1], stream); + + const int64_t ne11_flat = ne12*n_ids; + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + size_t nbytes_src1_q8_1 = ne11_flat*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq); + ggml_cuda_pool_alloc src1_quantized(ctx.pool(), nbytes_src1_q8_1); + + size_t ts_src1 = ggml_type_size(src1->type); + quantize_mmq_q8_1_cuda_id((const float *)src1->data, ids_src1, src1_quantized.get(), + src0_1->type, ne10, src1->nb[1] / ts_src1, src1->nb[2] / ts_src1, src1->nb[2] / ts_src1, + ne10_padded, ne11_flat, 1, 1, stream); + + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + + dst_row.data = dst_up_contiguous.get(); + ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get()); + if (dst->src[4]) { + ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[4]->data, (const int32_t *)ids->data, + (float 
*)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
+                dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream);
+            CUDA_CHECK(cudaGetLastError());
+        }
+
+        dst_row.data = dst_gate_contiguous.get();
+        ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get());
+        if (dst->src[5]) {
+            ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[5]->data, (const int32_t *)ids->data,
+                (float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1],
+                dst_row.nb[1], dst_row.nb[2], dst->src[5]->nb[1], ids->nb[1], stream);
+            CUDA_CHECK(cudaGetLastError());
+        }
+
+        auto unary_op = (ggml_unary_op)dst->op_params[0];
+        if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
+            ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
+                (float *)dst->data, ggml_nelements(dst), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0],
+                1.702f, 7.0f, stream);
+        } else {
+            ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
+                (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
+                (float *)dst->data);
+        }
+        CUDA_CHECK(cudaGetLastError());
+
+        if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
+            ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
+            //ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, (char *)ids_device.get(), nullptr);
+            ggml_cuda_mul_mat_q_id(ctx, next->src[0], dst, ids, next, nullptr, nullptr);
+            return true;
+        }
+
+        return false;
+    }
+
     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
     CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
@@ -2680,7 +2756,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
     ggml_tensor src0_1_row = *src0_1;
     ggml_tensor src0_2_row = *src0_2;
     ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;

     ggml_tensor final_dst;
     ggml_tensor final_src;
@@ -2723,222 +2798,170 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
         final_src.nb[3] = final_src.nb[2];
     }

-    if (false && ne12 == 1) {
-        ggml_cuda_pool_alloc<float> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
-        ggml_cuda_pool_alloc<float> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
-        if (fuse_down) {
-            final_dst.src[1] = &dst_row;
+    ggml_cuda_pool_alloc<char> src1_quantized(ctx.pool());
+    bool use_quantized_src1 = false;
+    int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0;
+    if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) {
+        if (ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
+            src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING);
+            src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1);
+            src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq);
+            src1_quantized.alloc(src1_quantized_size);
+            use_quantized_src1 = true;
        }
-        for (int64_t id = 0; id < n_ids; id++) {
-            const int32_t i02 = *(const int32_t *) (ids_host.data() + id*ids->nb[0]);
-
-            if (i02 < 0 || i02 >= n_as) continue;
-            //GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
-            const int64_t i11 = id % 
ne11; - const int64_t i12 = 0; - - const int64_t i1 = id; - const int64_t i2 = i12; - - src0_1_row.data = src0_1_original + i02*nb02; - src0_2_row.data = src0_2_original + i02*nb02; - src1_row.data = src1_original + i11*nb11 + i12*nb12; - //dst_row.data = dst_original + i1*nb1 + i2*nb2; - - dst_row.data = dst_up_contiguous.get(); - ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); - CUDA_CHECK(cudaGetLastError()); - - dst_row.data = dst_gate_contiguous.get(); - ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); - CUDA_CHECK(cudaGetLastError()); + } + ggml_cuda_pool_alloc src1_contiguous(ctx.pool()); + if (!use_quantized_src1) { + src1_contiguous.alloc(sizeof(float)*ggml_nelements(src1)); + } + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc final_dst_contiguous(ctx.pool()); + if (fuse_down) { + final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next)); + final_dst.src[1] = &dst_row; + } - if (fuse_down) { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); - CUDA_CHECK(cudaGetLastError()); + src1_row.data = src1_contiguous.get(); - final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; - final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2]; - ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); - CUDA_CHECK(cudaGetLastError()); + bool first = false; //true; - } else { + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool()); + std::vector moe_counts, cum_moe_counts; - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2)); - CUDA_CHECK(cudaGetLastError()); - - } - } - } else { - //printf("ne10 = %ld, ne11 = %ld, ne12 = %ld, nb10 = %zu nb11 = %zu nb12 = %zu\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[0], src1->nb[1], src1->nb[2]); - ggml_cuda_pool_alloc src1_quantized(ctx.pool()); - bool use_quantized_src1 = false; - int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0; - if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) { - if (ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { - src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); - src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1); - src1_quantized_size = src1_padded_row_size*src1->ne[2] + get_mmq_x_max_host(ggml_cuda_info().devices[ctx.device].cc)*sizeof(block_q8_1_mmq); - src1_quantized.alloc(src1_quantized_size); - use_quantized_src1 = true; - } - } - ggml_cuda_pool_alloc src1_contiguous(ctx.pool()); - if (!use_quantized_src1) { - src1_contiguous.alloc(sizeof(float)*ggml_nelements(src1)); - } - ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - ggml_cuda_pool_alloc final_dst_contiguous(ctx.pool()); + bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping); + if (is_ser) { if (fuse_down) { - final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next)); - 
final_dst.src[1] = &dst_row; + CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream)); + } else { + CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream)); } + } - src1_row.data = src1_contiguous.get(); + for (int64_t i02 = 0; i02 < n_as; i02++) { + int64_t num_src1_rows = moe_counts[i02]; - bool first = false; //true; + if (num_src1_rows == 0) continue; + size_t mapping_offset = cum_moe_counts[i02]; - ggml_cuda_pool_alloc dev_row_mapping(ctx.pool()); - std::vector moe_counts, cum_moe_counts; - - bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping); - if (is_ser) { - if (fuse_down) { - CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream)); - } else { - CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream)); - } + if (use_quantized_src1) { + quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset), + src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream); + CUDA_CHECK(cudaGetLastError()); + src1_row.data = src1_quantized.get(); + } + else { + dim3 block_dims(std::min((unsigned int)ne10, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_src_to_contiguous<<>>( + src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12); + CUDA_CHECK(cudaGetLastError()); + src1_row.data = src1_contiguous.get(); } - for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = moe_counts[i02]; - - if (num_src1_rows == 0) continue; - size_t mapping_offset = cum_moe_counts[i02]; + src0_1_row.data = src0_1_original + i02*nb02; + src0_2_row.data = src0_2_original + i02*nb02; - if (use_quantized_src1) { - quantize_mmq_q8_1_id_cuda((const float *)src1->data, src1_quantized.get(), (const char *)(dev_row_mapping.get() + mapping_offset), - src1->ne[0], num_src1_rows, src1_padded_num_cols, src0_1->type, stream); - CUDA_CHECK(cudaGetLastError()); - src1_row.data = src1_quantized.get(); - } - else { - dim3 block_dims(std::min((unsigned int)ne10, 768u)); - dim3 grid_dims(num_src1_rows); - k_copy_src_to_contiguous<<>>( - src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12); - CUDA_CHECK(cudaGetLastError()); - src1_row.data = src1_contiguous.get(); - } + GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); - src0_1_row.data = src0_1_original + i02*nb02; - src0_2_row.data = src0_2_original + i02*nb02; - - GGML_ASSERT(nb11 == sizeof(float)*ne10); - GGML_ASSERT(nb1 == sizeof(float)*ne0); + src1_row.ne[1] = num_src1_rows; + src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11; + src1_row.nb[2] = num_src1_rows*src1_row.nb[1]; + src1_row.nb[3] = num_src1_rows*src1_row.nb[1]; - src1_row.ne[1] = num_src1_rows; - src1_row.nb[1] = use_quantized_src1 ? 
src1_padded_row_size : nb11; - src1_row.nb[2] = num_src1_rows*src1_row.nb[1]; - src1_row.nb[3] = num_src1_rows*src1_row.nb[1]; + dst_row.ne[1] = num_src1_rows; + dst_row.nb[1] = nb1; + dst_row.nb[2] = num_src1_rows*nb1; + dst_row.nb[3] = num_src1_rows*nb1; - dst_row.ne[1] = num_src1_rows; - dst_row.nb[1] = nb1; - dst_row.nb[2] = num_src1_rows*nb1; - dst_row.nb[3] = num_src1_rows*nb1; + dst_row.data = dst_up_contiguous.get(); + if (use_quantized_src1) { + ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, + 0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); + } else { + ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); + } + CUDA_CHECK(cudaGetLastError()); - dst_row.data = dst_up_contiguous.get(); - if (use_quantized_src1) { - ggml_cuda_op_mul_mat_q(ctx, &src0_1_row, &src1_row, &dst_row, (const char *)src0_1_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, - 0, src0_1_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); - } else { - ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); - } + if (dst->src[4]) { + dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); + dim3 grid_dims(num_src1_rows); + k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, + (const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data); CUDA_CHECK(cudaGetLastError()); + } - if (dst->src[4]) { - dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); - dim3 grid_dims(num_src1_rows); - k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, - (const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data); - CUDA_CHECK(cudaGetLastError()); - } + dst_row.data = dst_gate_contiguous.get(); + if (use_quantized_src1) { + ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, + 0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); + } else { + ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + } + CUDA_CHECK(cudaGetLastError()); - dst_row.data = dst_gate_contiguous.get(); - if (use_quantized_src1) { - ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data, - 0, src0_2_row.ne[1], num_src1_rows, src1_padded_num_cols, stream); - } else { - ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); - } + if (dst->src[5]) { + dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); + dim3 grid_dims(num_src1_rows); + k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, + (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data); CUDA_CHECK(cudaGetLastError()); + } - if (dst->src[5]) { - dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); - dim3 grid_dims(num_src1_rows); - k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, - (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data); - CUDA_CHECK(cudaGetLastError()); - } + auto unary_op = (ggml_unary_op)dst->op_params[0]; + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0], + 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, 
(ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get()); + } + CUDA_CHECK(cudaGetLastError()); - auto unary_op = (ggml_unary_op)dst->op_params[0]; - if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { - ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), - (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0], - 1.702f, 7.0f, stream); - } else { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), - (float *)dst_gate_contiguous.get()); + if (fuse_down) { + + final_dst.ne[1] = num_src1_rows; + final_dst.nb[1] = final_dst.ne[0]*sizeof(float); + final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1]; + final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; + if (first) { + printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows, + (int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3], + (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], + (int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]); + printf(" using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", + (int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3], + (int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3], + (int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]); + first = false; } + ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); + //ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst); CUDA_CHECK(cudaGetLastError()); - if (fuse_down) { - - final_dst.ne[1] = num_src1_rows; - final_dst.nb[1] = final_dst.ne[0]*sizeof(float); - final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1]; - final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; - if (first) { - printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows, - (int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3], - (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], - (int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]); - printf(" using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", - (int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3], - (int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3], - (int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]); - first = false; - } - ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); - //ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst); - CUDA_CHECK(cudaGetLastError()); - - dim3 block_dims(std::min((unsigned int)next->ne[0], 768u)); - dim3 grid_dims(num_src1_rows); - k_copy_dst_from_contiguous<<>>( - (char *)next->data, final_dst_contiguous.get(), - dev_row_mapping.get() + mapping_offset, - next->ne[0], - next->nb[1], next->nb[2]); - CUDA_CHECK(cudaGetLastError()); - - } - else { + dim3 block_dims(std::min((unsigned int)next->ne[0], 
768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + (char *)next->data, final_dst_contiguous.get(), + dev_row_mapping.get() + mapping_offset, + next->ne[0], + next->nb[1], next->nb[2]); + CUDA_CHECK(cudaGetLastError()); - dim3 block_dims(std::min((unsigned int)ne0, 768u)); - dim3 grid_dims(num_src1_rows); - k_copy_dst_from_contiguous<<>>( - dst_original, dst_gate_contiguous.get(), - dev_row_mapping.get() + mapping_offset, - ne0, - nb1, nb2); - CUDA_CHECK(cudaGetLastError()); - } + } + else { + + dim3 block_dims(std::min((unsigned int)ne0, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + dst_original, dst_gate_contiguous.get(), + dev_row_mapping.get() + mapping_offset, + ne0, + nb1, nb2); + CUDA_CHECK(cudaGetLastError()); } } diff --git a/ggml/src/ggml-cuda/mmq_id.cu b/ggml/src/ggml-cuda/mmq_id.cu new file mode 100644 index 000000000..230715c0e --- /dev/null +++ b/ggml/src/ggml-cuda/mmq_id.cu @@ -0,0 +1,564 @@ +#include "mmq_id_common.cuh" +#include "mmq_id.cuh" +#include "quantize_id.cuh" + +#include +#include +#include + +// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. +struct mmq_ids_helper_store { + uint32_t data; + + __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) { + data = (it & 0x003FFFFF) | (iex_used << 22); + } + + __device__ uint32_t it() const { + return data & 0x003FFFFF; + } + + __device__ uint32_t iex_used() const { + return data >> 22; + } +}; +static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store"); + +// Helper function for mul_mat_id, converts ids to a more convenient format. +// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. +// ids_dst describes the same mapping but for the dst tensor. +// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. +template +__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mmq_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; + const int expert = blockIdx.x; + + extern __shared__ char data_mmq_ids_helper[]; + mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper; + + int nex_prev = 0; // Number of columns for experts with a lower index. + int it_compact = 0; // Running index for the compact slice of this expert. + + if constexpr (n_expert_used_template == 0) { + // Generic implementation: + for (int it = 0; it < n_tokens; ++it) { + int iex_used = -1; // The index at which the expert is used, if any. 
+ for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { + const int expert_used = ids[it*si1 + iex]; + nex_prev += expert_used < expert; + if (expert_used == expert) { + iex_used = iex; + } + } + + if (iex_used != -1) { + store[it_compact] = mmq_ids_helper_store(it, iex_used); + } + + if (warp_reduce_any(iex_used != -1)) { + it_compact++; + } + } + } else { + // Implementation optimized for specific numbers of experts used: + static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); + const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. + for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { + const int it = it0 + threadIdx.x / neu_padded; + + const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. + const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? + ids[it*si1 + iex] : INT_MAX; + const int iex_used = expert_used == expert ? iex : -1; + nex_prev += expert_used < expert; + + // Whether the threads at this token position have used the expert: + const int it_compact_add_self = warp_reduce_any(iex_used != -1); + + // Do a scan over threads at lower token positions in warp to get the correct index for writing data: + int it_compact_add_lower = 0; +#pragma unroll + for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { + const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); + if (threadIdx.x >= offset) { + it_compact_add_lower += tmp; + } + } + + if (iex_used != -1) { + store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used); + } + + // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: + it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); + } + } + nex_prev = warp_reduce_sum(nex_prev); + + for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { + const mmq_ids_helper_store store_it = store[itc]; + const int it = store_it.it(); + const int iex_used = store_it.iex_used(); + ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; + ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; + } + + if (threadIdx.x != 0) { + return; + } + + expert_bounds[expert] = nex_prev; + + if (expert < gridDim.x - 1) { + return; + } + + expert_bounds[gridDim.x] = nex_prev + it_compact; +} + +template +static void launch_mmq_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store"); + GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store"); + + const int id = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper, smpbo); + + const dim3 num_blocks(n_experts, 1, 1); + const dim3 block_size(warp_size, 1, 1); + const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store); + mmq_ids_helper<<>> + (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); +} + +static 
void ggml_cuda_mul_mat_q_switch_type_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) {
+    switch (args.type_x) {
+        case GGML_TYPE_Q4_0:
+            mul_mat_q_case_id<GGML_TYPE_Q4_0>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            mul_mat_q_case_id<GGML_TYPE_Q4_1>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            mul_mat_q_case_id<GGML_TYPE_Q5_0>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            mul_mat_q_case_id<GGML_TYPE_Q5_1>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q6_0:
+            mul_mat_q_case_id<GGML_TYPE_Q6_0>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            mul_mat_q_case_id<GGML_TYPE_Q8_0>(ctx, args, stream);
+            break;
+        case GGML_TYPE_MXFP4:
+            mul_mat_q_case_id<GGML_TYPE_MXFP4>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            mul_mat_q_case_id<GGML_TYPE_Q2_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            mul_mat_q_case_id<GGML_TYPE_Q3_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            mul_mat_q_case_id<GGML_TYPE_Q4_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            mul_mat_q_case_id<GGML_TYPE_Q5_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            mul_mat_q_case_id<GGML_TYPE_Q6_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_q_case_id<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_q_case_id<GGML_TYPE_IQ2_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_q_case_id<GGML_TYPE_IQ2_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_q_case_id<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_q_case_id<GGML_TYPE_IQ3_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_q_case_id<GGML_TYPE_IQ1_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ1_S_R4:
+            mul_mat_q_case_id<GGML_TYPE_IQ1_S_R4>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_q_case_id<GGML_TYPE_IQ4_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_q_case_id<GGML_TYPE_IQ4_NL>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_KS:
+            mul_mat_q_case_id<GGML_TYPE_IQ2_KS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_KL:
+            mul_mat_q_case_id<GGML_TYPE_IQ2_KL>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_K:
+            mul_mat_q_case_id<GGML_TYPE_IQ2_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_K_R4:
+            mul_mat_q_case_id<GGML_TYPE_IQ2_K_R4>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_K:
+            mul_mat_q_case_id<GGML_TYPE_IQ3_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_K_R4:
+            mul_mat_q_case_id<GGML_TYPE_IQ3_K_R4>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_KS:
+            mul_mat_q_case_id<GGML_TYPE_IQ3_KS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_KSS:
+            mul_mat_q_case_id<GGML_TYPE_IQ4_KSS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_KS:
+            mul_mat_q_case_id<GGML_TYPE_IQ4_KS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_KS_R4:
+            mul_mat_q_case_id<GGML_TYPE_IQ4_KS_R4>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_K:
+            mul_mat_q_case_id<GGML_TYPE_IQ4_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_K_R4:
+            mul_mat_q_case_id<GGML_TYPE_IQ4_K_R4>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ5_KS:
+            mul_mat_q_case_id<GGML_TYPE_IQ5_KS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ5_KS_R4:
+            mul_mat_q_case_id<GGML_TYPE_IQ5_KS_R4>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ5_K:
+            mul_mat_q_case_id<GGML_TYPE_IQ5_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ5_K_R4:
+            mul_mat_q_case_id<GGML_TYPE_IQ5_K_R4>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ6_K:
+            mul_mat_q_case_id<GGML_TYPE_IQ6_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ1_KT:
+            mul_mat_q_case_id<GGML_TYPE_IQ1_KT>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_KT:
+            mul_mat_q_case_id<GGML_TYPE_IQ2_KT>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_KT:
+            mul_mat_q_case_id<GGML_TYPE_IQ3_KT>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_KT:
+            mul_mat_q_case_id<GGML_TYPE_IQ4_KT>(ctx, args, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
+        int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21,
+        cudaStream_t stream) {
+
+    const 
int si1 = nb21 / sizeof(int); + const int sis1 = nb12 / nb11; + + switch (n_expert_used) { + case 2: + launch_mmq_ids_helper< 2> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 4: + launch_mmq_ids_helper< 4> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 6: + launch_mmq_ids_helper< 6> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 8: + launch_mmq_ids_helper< 8> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 16: + launch_mmq_ids_helper<16> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 32: + launch_mmq_ids_helper<32> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + default: + launch_mmq_ids_helper< 0> (ids, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + } + CUDA_CHECK(cudaGetLastError()); +} + +void ggml_cuda_mul_mat_q_id(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, + const ggml_tensor * ids_tensor, ggml_tensor * dst, char * ids_data, char * src1_quantized_data) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(ids_tensor->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. + + GGML_TENSOR_BINARY_OP_LOCALS; + + cudaStream_t stream = ctx.stream(); + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + //const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + //GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT( nb0 == ts_dst); + GGML_ASSERT(ids_tensor->nb[0] == ggml_type_size(ids_tensor->type)); + + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(nb12 % nb11 == 0); + GGML_ASSERT(nb2 % nb1 == 0); + + const char * src0_d = (const char *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + // If src0 is a temporary compute buffer, clear any potential padding. 
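+    // (Assumed rationale, mirroring the regular MMQ path: the quantized x data is loaded in fixed-size
+    // tiles, so bytes between ggml_nbytes(src0) and the allocated buffer size could otherwise be read
+    // as uninitialized data; zeroing them keeps the padded reads harmless.)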
+ if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { + const size_t size_data = ggml_nbytes(src0); + const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0); + if (size_alloc > size_data) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); + CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream)); + } + } + + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + + const int64_t s01 = src0->nb[1];// / ts_src0; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t s02 = src0->nb[2];// / ts_src0; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3];// / ts_src0; + const int64_t s3 = dst->nb[3] / ts_dst; + + const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || GGML_CUDA_CC_IS_CDNA(cc); + + const int64_t n_expert_used = ids_tensor->ne[0]; + const int64_t ne_get_rows = ne12 * n_expert_used; + GGML_ASSERT(ne1 == n_expert_used); + + ggml_cuda_pool_alloc ids_src1_local(ctx.pool()); + ggml_cuda_pool_alloc ids_dst_local(ctx.pool()); + ggml_cuda_pool_alloc expert_bounds_local(ctx.pool()); + + int32_t * ids_src1, *ids_dst, *expert_bounds; + if (ids_data) { + ids_src1 = (int32_t *)ids_data; + ids_dst = ids_src1 + ne_get_rows; + expert_bounds = ids_dst + ne_get_rows; + } + else { + GGML_ASSERT(ids_tensor->nb[0] == ggml_element_size(ids_tensor)); + + ids_src1_local.alloc(ne_get_rows); + ids_dst_local.alloc(ne_get_rows); + expert_bounds_local.alloc(ne02 + 1); + + ids_src1 = ids_src1_local.get(); + ids_dst = ids_dst_local.get(); + expert_bounds = expert_bounds_local.get(); + + const int si1 = ids_tensor->nb[1] / ggml_element_size(ids_tensor); + const int sis1 = nb12 / nb11; + + switch (n_expert_used) { + case 2: + launch_mmq_ids_helper< 2> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 4: + launch_mmq_ids_helper< 4> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 6: + launch_mmq_ids_helper< 6> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 8: + launch_mmq_ids_helper< 8> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 16: + launch_mmq_ids_helper<16> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + case 32: + launch_mmq_ids_helper<32> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + default: + launch_mmq_ids_helper< 0> ((const int32_t *) ids_tensor->data, ids_src1, ids_dst, expert_bounds, + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + break; + } + CUDA_CHECK(cudaGetLastError()); + } + + const int64_t ne11_flat = ne12*n_expert_used; + const int64_t ne12_flat = 1; + const int64_t ne13_flat = 1; + + const size_t nbytes_src1_q8_1 = ne11_flat*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); + + ggml_cuda_pool_alloc src1_q8_1_local(ctx.pool()); + + char * src1_q8_1; + + if (src1_quantized_data) { + src1_q8_1 = src1_quantized_data; + } else { + + 
src1_q8_1_local.alloc(nbytes_src1_q8_1); + src1_q8_1 = src1_q8_1_local.get(); + + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s13 = src1->nb[2] / ts_src1; + quantize_mmq_q8_1_cuda_id(src1_d, ids_src1, src1_q8_1, src0->type, + ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + CUDA_CHECK(cudaGetLastError()); + } + + const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); + const int64_t s13 = ne12*s12; + + // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. + const mmq_args_id args = { + src0_d, src0->type, (const int *) src1_q8_1, ids_dst, expert_bounds, dst_d, + ne00, ne01, ne_get_rows, s01, ne_get_rows, s1, + ne02, ne02, s02, s12, s2, + ne03, ne13, s03, s13, s3, + use_stream_k, ne12}; + + //printf("ne00 = %ld, ne01 = %ld, ne_get_rows = %ld, s01 = %ld, s1 = %ld\n", ne00, ne01, ne_get_rows, s01, s1); + //printf("ne02 = %ld, s02 = %ld, s12 = %ld, s2 = %ld\n", ne02, s02, s12, s2); + //printf("ne03 = %ld, s03 = %ld, s13 = %ld, s3 = %ld\n", ne03, s03, s13, s3); + + ggml_cuda_mul_mat_q_switch_type_id(ctx, args, stream); +} + +bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11) { + bool mmq_supported; + + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q6_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_S_R4: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_KL: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: + case GGML_TYPE_IQ3_KS: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_IQ5_KS: + case GGML_TYPE_IQ5_KS_R4: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: + case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ1_KT: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: + mmq_supported = true; + break; + default: + mmq_supported = false; + break; + } + + if (!mmq_supported) { + return false; + } + + if (turing_mma_available(cc)) { + return true; + } + + if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) { + return false; + } + +#ifdef GGML_CUDA_FORCE_MMQ + return true; +#endif //GGML_CUDA_FORCE_MMQ + + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + } + + if (amd_mfma_available(cc)) { + // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT) + // performs better but is currently suffering from a crash on this architecture. 
+ // TODO: Revisit when hipblaslt is fixed on CDNA3 + if (GGML_CUDA_CC_IS_CDNA3(cc)) { + return true; + } + if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 + || type == GGML_TYPE_Q5_1 || type == GGML_TYPE_Q6_0) { + return true; + } + if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) { + return true; + } + return false; + } + + return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + +} diff --git a/ggml/src/ggml-cuda/mmq_id.cuh b/ggml/src/ggml-cuda/mmq_id.cuh new file mode 100644 index 000000000..567393074 --- /dev/null +++ b/ggml/src/ggml-cuda/mmq_id.cuh @@ -0,0 +1,12 @@ +#pragma once + +#include "common.cuh" + +void ggml_cuda_mul_mat_q_id( + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, + ggml_tensor * dst, char * ids_data, char * src1_quantized_data); + +void compute_row_ids(const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, + int64_t ne02, int64_t ne12, int64_t n_expert_used, int64_t ne11, int64_t nb11, int64_t nb12, int64_t nb21, cudaStream_t stream); + +bool ggml_cuda_can_use_mmq_id(enum ggml_type type, int cc, int64_t ne11); diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh new file mode 100644 index 000000000..89baa31b6 --- /dev/null +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -0,0 +1,4183 @@ +#pragma once + +#include "common.cuh" +#include "mma_new.cuh" +#include "vecdotq.cuh" +#include "iqk_cuda_common.h" + +#include +#include +#include + +using namespace ggml_cuda_mma; + +#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. +#define MMQ_ITER_K 256 +#define MMQ_NWARPS 8 + +typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); +typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00); +typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted, + float * __restrict__ dst, const int stride, const int i_max, const int j_max); + +enum mmq_q8_1_ds_layout { + MMQ_Q8_1_DS_LAYOUT_D4, + MMQ_Q8_1_DS_LAYOUT_DS4, + MMQ_Q8_1_DS_LAYOUT_D2S6, +}; + +struct block_q8_1_mmq { + // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block. + // The y float data is first grouped as blocks of 128 values. + // These blocks are then treated as individual data values and transposed. + // + // To avoid shared memory bank conflicts each block is padded with 16 bytes. + // This padding is also used to store block scales/partial sums. + // The scales multiplied with the quantized data are equal to the unquantized values. + // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization) + // and are only needed for performance reasons. + // + // The exact data stored depends on the x data type. 
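+    // For example, with the MMQ_Q8_1_DS_LAYOUT_D4 layout a single block_q8_1_mmq covers 128 int8
+    // values and d4[i] is the float scale for values [32*i, 32*(i+1)); the DS4 layout stores the same
+    // four scales as halves together with per-32-value partial sums in ds4[i].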
+ union { + float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3 + half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3 + half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values, + // stored as d0,d1,s1,s2,s3,s4,s5 + }; + int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each +}; +static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); +static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); + +static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { + switch (type_x) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q5_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q5_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q6_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q8_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_MXFP4: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q2_K: + return MMQ_Q8_1_DS_LAYOUT_D2S6; + case GGML_TYPE_Q3_K: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_S_R4: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_KL: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: + case GGML_TYPE_IQ3_KS: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_IQ5_KS: + case GGML_TYPE_IQ5_KS_R4: + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: + case GGML_TYPE_IQ6_K: + case GGML_TYPE_IQ1_KT: + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: + return MMQ_Q8_1_DS_LAYOUT_D4; + default: + GGML_ABORT("fatal error"); + break; + } +} + +struct tile_x_sizes { + int qs; + int dm; + int sc; +}; + +#define GGML_CUDA_CC_PASCAL 600 +#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#define GGML_CUDA_CC_VOLTA 700 +#define GGML_CUDA_CC_TURING 750 +#define GGML_CUDA_CC_AMPERE 800 +#define GGML_CUDA_CC_ADA_LOVELACE 890 +#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 +#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000 + +// AMD +// GCN/CDNA, wave size is 64 +#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 +#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue +#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a +#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers +#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing +#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 + +// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 +#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 +#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a +#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 
7000, minimum for WMMA +#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000 + +#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) +#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) +#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) +#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) + +#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 +#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 +#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD + +#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD) +#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2) +#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG) +#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG) + +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) +#define GGML_CUDA_ASSUME(x) __builtin_assume(x) +#else +#define GGML_CUDA_ASSUME(x) +#endif // CUDART_VERSION >= 11010 + +#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) +#define GGML_USE_VMM +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) + +#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#define FP16_MMA_AVAILABLE +#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) + +#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) +#define FP16_MMA_AVAILABLE +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) + +#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) +#define AMD_MFMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING +#define TURING_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define AMPERE_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define CP_ASYNC_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + +#ifdef __CUDACC__ +template +__host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexcept {} +#define GGML_UNUSED_VARS(...) ggml_unused_vars_impl(__VA_ARGS__) +#else +#define GGML_UNUSED_VARS(...) 
do { (void)sizeof((__VA_ARGS__, 0)); } while(0) +#endif // __CUDACC__ + +static bool amd_mfma_available(const int cc) { +#if !defined(GGML_HIP_NO_MMQ_MFMA) + return GGML_CUDA_CC_IS_CDNA(cc); +#else + return false; +#endif //!defined(GGML_HIP_NO_MMQ_MFMA) +} +static bool turing_mma_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_TURING; +} + +static int get_mmq_x_max_host(const int cc) { + return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 : + GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_VOLTA ? +#ifdef GGML_CUDA_FORCE_MMQ + 128 : 64; +#else + MMQ_DP4A_MAX_BATCH_SIZE : 64; +#endif // GGML_CUDA_FORCE_MMQ +} +static constexpr __device__ int ggml_cuda_get_physical_warp_size() { +#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) + return 64; +#else + return 32; +#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) +} +static constexpr int ggml_cuda_get_physical_warp_size_host() { +#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) + return 64; +#else + return 32; +#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) +} +static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { +#if CUDART_VERSION >= 12080 + const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x); + return (float) e; +#else + uint32_t bits; + if (x == 0) { + bits = 0x00400000; + } else { + bits = (uint32_t) x << 23; + } + + float result; + memcpy(&result, &bits, sizeof(float)); + return result; +#endif // CUDART_VERSION >= 12050 +} +template + +static __device__ __forceinline__ int warp_reduce_any(int x) { + if (width == ggml_cuda_get_physical_warp_size()) { + return __any_sync(0xffffffff, x); + } else { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x = __shfl_xor_sync(0xffffffff, x, offset, width) || x; + } + return x; + } +} +template +static __device__ __forceinline__ int warp_reduce_sum(int x) { +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + return __reduce_add_sync(0xffffffff, x); +#else +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, offset, width); + } + return x; +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +} +template +static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width); + a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width); + } + return a; +} +template +static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { +#ifdef FP16_AVAILABLE +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width)); + } + return a; + +#else + NO_DEVICE_CODE; + return a; +#endif // FP16_AVAILABLE +} + +static bool fp16_mma_hardware_available(const int cc) { + return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); +} + + +static constexpr __device__ int get_mmq_x_max_device() { +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + return 128; +#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + +#if defined(GGML_USE_HIP) + return 64; +#else // defined(GGML_USE_HIP) + +#if __CUDA_ARCH__ 
>= GGML_CUDA_CC_VOLTA +#ifdef GGML_CUDA_FORCE_MMQ + return 128; +#else // GGML_CUDA_FORCE_MMQ + return MMQ_DP4A_MAX_BATCH_SIZE; +#endif // GGML_CUDA_FORCE_MMQ +#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + return 64; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + +#endif // defined(GGML_USE_HIP) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +} + +static int get_mmq_y_host(const int cc) { + return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : + ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64); +} + +static constexpr __device__ int get_mmq_y_device() { +#if defined(GGML_USE_HIP) +#if defined(RDNA1) + return 64; +#else + return 128; +#endif // defined RDNA1 +#else +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + return 128; +#else + return 64; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(GGML_USE_HIP) +} + +// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes. +// The K dimension of the tiles has either, +// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K), +// 32 bit elements for the quantized data (does not include scales). +// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K. +// The final tile size in K direction is padded to avoid shared memory bank conflicts, +// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma. +#define MMQ_TILE_NE_K 32 + +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0} +#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0} +#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} + +static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; + case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; + case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q6_0 : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; + case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K; + case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_XS: 
return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ1_S_R4:return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_KS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_KL : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_K_R4: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_KS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_K_R4: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ4_KSS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_KS_R4: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ4_K_R4: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ5_KS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ5_KS_R4: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ5_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ5_K_R4: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ6_K : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ1_KT : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_KT : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_KT : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_KT : return MMQ_DP4A_TXS_Q8_0; + default: return tile_x_sizes{0, 0, 0}; + } +} + +#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) + +static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); + +static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ1_S_R4:return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_KS : return 
MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_KL : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_KS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ4_KSS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ4_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ5_KS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ5_KS_R4: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ5_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ5_K_R4: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ6_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ1_KT : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_KT : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_KT : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_KT : return MMQ_MMA_TILE_X_K_Q8_0; + default: return 0; + } +} + +// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales) +#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) + +static int mmq_get_granularity_host(const int mmq_x, const int cc) { + if (amd_mfma_available(cc)) { + return mmq_x >= 128 ? 32 : 16; + } else if (turing_mma_available(cc) && mmq_x >= 48) { + return 16; + } else { + return 8; + } +} + +#if defined(AMD_MFMA_AVAILABLE) +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 128 ? 32 : 16; +} +#elif defined(TURING_MMA_AVAILABLE) +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 48 ? 16 : 8; +} +#else +static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) { + return 8; +} +#endif // AMD_MFMA_AVAILABLE + +#if defined(GGML_USE_HIP) +static int mmq_get_nwarps_host(const int cc, const int warp_size) { + return amd_mfma_available(cc) ? 8 : 256/warp_size; +} +#else +static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) { + return 256/warp_size; +} +#endif // (GGML_USE_HIP) + +static constexpr __device__ int mmq_get_nwarps_device() { +#if defined(AMD_MFMA_AVAILABLE) + return 8; +#else + return 256/ggml_cuda_get_physical_warp_size(); +#endif // AMD_MFMA_AVAILABLE +} + +// ------------------------------------------------------------ + +template static __device__ __forceinline__ void load_tiles_q4_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_0; + const int kqsx = txi % QI4_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = (const block_q4_0 *)(x + i*stride) + kbx0 + kbx; + const int qs0 = get_int_b2(bxi->qs, kqsx); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); +#else + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = (const block_q4_0 *)(x + i*stride) + kbx0 + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; + } + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template static __device__ __forceinline__ void load_tiles_q4_1( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + int * x_qs = (int *) x_tile; + 
half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_1; + const int kqsx = txi % QI4_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = (const block_q4_1 *)(x + i*stride) + kbx0 + kbx; + const int qs0 = get_int_b4(bxi->qs, kqsx); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; +#else + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = (const block_q4_1 *)(x + i*stride) + kbx0 + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#else + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; + } + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template static __device__ __forceinline__ void load_tiles_q5_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || 
defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_0; + const int kqsx = txi % QI5_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = (const block_q5_0 *)(x + i*stride) + kbx0 + kbx; + + const int ql = get_int_b2(bxi->qs, kqsx); + const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = (const block_q5_0 *)(x + i*stride) + kbx0 + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_q5_1( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_1; + const int kqsx = txi % QI5_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = (const block_q5_1 *)(x + i*stride) + kbx0 + kbx; + + const int ql = get_int_b4(bxi->qs, kqsx); + const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = (const block_q5_1 *)(x + i*stride) + kbx0 + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#else + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_q6_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kbx = threadIdx.x / QI6_0; + const int kqsx = threadIdx.x % QI6_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbx; + + const int ql = get_int_b2(bxi->qs, kqsx); + const int qh = get_int_b2(bxi->qh, kqsx%2) >> 4*(kqsx/2); + + int qs0 = ((ql >> 0) & 0x0F0F0F0F) | ((qh << 4) & 0x30303030); + int qs1 = ((ql >> 4) & 0x0F0F0F0F) | ((qh << 2) & 0x30303030); + qs0 = __vsubss4(qs0, 0x20202020); // subtract 32 + qs1 = __vsubss4(qs1, 0x20202020); // subtract 32 + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI6_0) + kqsx + QI6_0] = qs1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI6_0) + kqsx + 0] = qs0; + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI6_0) + kqsx + 
QI6_0] = qs1; +#endif // INT8_MMA_AVAILABLE + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI6_0; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_0) { + int i = i0 + threadIdx.y * QI6_0 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbxd; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI6_0) + i/QI6_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_q8_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp + constexpr int threads_per_row = 32; + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI8_0; + const int kqsx = txi % QI8_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = (const block_q8_0 *)(x + i*stride) + kbx0 + kbx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = (const block_q8_0 *)(x + i*stride) + kbx0 + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_mxfp4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) 
x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI_MXFP4; + const int kqsx = txi % QI_MXFP4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbx; + + const int aux_q4 = get_int_b1(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); + const int k0 = kbx * (2 * QI_MXFP4) + kqsx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; +#else + x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K], + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if 
defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + float dB; + const int j = j0 + tile_C::get_j(0); + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB; + } + } + } + } +#else + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
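    // [Descriptive annotation, not part of the original patch.] Worked example of the
    // warp-to-tile mapping that the constants above set up on this (Turing mma) path:
    // tile_C is 16x8, so with mmq_x >= 48 the granularity is 16, rows_per_warp == 32 and
    // ntx == 32/16 == 2; with mmq_x < 48 the granularity is 8 and ntx == 1. Below,
    // threadIdx.y % ntx selects which tile_C::J-wide slice of y a warp works on, while
    // threadIdx.y / ntx selects the rows_per_warp-high slice of the x tile.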
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + + tile_A A[ntx][MMQ_TILE_NE_K/QI8_0]; + float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0]; + + const int i0 = (threadIdx.y/ntx)*rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + tile_B B; + float dB[tile_C::ne/2]; + + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n][k01/QI8_0], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2]; + } + } + } + } +#endif // defined(AMD_MFMA_AVAILABLE) +} + +template +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
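    // [Descriptive annotation, not part of the original patch.] In this kernel the x-tile
    // scales are half2 (d, m) pairs and the q8_1 y scales are half2 pairs as well, so the
    // accumulation in the loop below splits into two terms: C.x[l]*dmA.x*dsB.x for the
    // integer dot product and dmA.y*dsB.y for the affine offset (the second y component
    // is, as far as I can tell, the pre-scaled sum of the q8 block; for Q4_K the m
    // component is already negated in load_tiles_q4_K, so a plain add suffices here).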
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y; + } + } + } + } +#else + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + tile_A A[ntx][MMQ_TILE_NE_K/QI8_1]; + float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1]; + + const int i0 = (threadIdx.y/ntx)*rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + tile_B B; + float2 dsB[tile_C::ne/2]; + + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n][k01/QI8_1], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y; + } + } + } + } +#endif // defined(AMD_MFMA_AVAILABLE) +} + +// Used for Q3_K, IQ2_S, and IQ2_XS +template +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + 
constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], + &y_qs[j*MMQ_TILE_Y_K + k01], + &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], + y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +// Used for Q3_K, IQ2_S, and IQ2_XS: +template +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; + } + } + } + } +#elif defined(TURING_MMA_AVAILABLE) + + typedef tile<16, 4, int> tile_A; + typedef tile<16, 8, int> tile_A_8; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
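    // [Descriptive annotation, not part of the original patch.] This "q8_0_16" variant keeps
    // one float scale per 16 quant values (one per 4 ints, see the k0/4 indexing of x_df
    // below), which matches the 16-value scale groups of Q3_K/IQ2_S/IQ2_XS. tile_A is
    // therefore only 4 ints wide so that each A fragment falls inside a single scale group;
    // the data is still fetched 8 ints at a time by reinterpreting pairs of fragments as a
    // 16x8 tile_A_8 in the load_ldmatrix call below.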
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); + + tile_A A[ntx][8]; + float dA[ntx][tile_C::ne/2][8]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { + const int k0 = k00 + k01; + + load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + tile_B B[2]; + float dB[tile_C::ne/2]; + + // Here load_generic is faster than load_ldmatrix. + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C[2]; + mma(C[0], A[n][k01/4 + 0], B[0]); + mma(C[1], A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]); + } + } + } + } +#else + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum), GGML_UNUSED(k00); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE +} + +template static __device__ __forceinline__ void load_tiles_q2_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); + constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = (const block_q2_K *)(x + i*stride) + kbx0; + + const int x_ql_0 = get_int_b2(bxi->qs, kqsx); + +#pragma unroll + for (int l = 0; l < QR2_K; ++l) { + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; + + const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int sc_m = bxi->scales[kqsx]; +#ifdef FAST_FP16_AVAILABLE + const half2 x_dm_ik = 
__hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4)); +#else + const float2 bxi_dmf = __half22float2(bxi->dm); + const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); +#endif // FAST_FP16_AVAILABLE + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; +#else + x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + float2 y_df[mmq_x/nwarps]; +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + constexpr int ns = 2; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + } + } + + // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop. + // As a workaround 2 separate loops are used instead. +#pragma unroll + for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + constexpr int ns = 1; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
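    // [Descriptive annotation, not part of the original patch.] Q2_K needs both a scale and
    // a minimum per 16-value group (packed as a half2 in load_tiles_q2_K), so the loop below
    // accumulates Cd*dm.x*dB and then subtracts the minimum contribution dm.y: for the last
    // quarter of the tile this requires the column sums of the q8 data, which are obtained
    // by an extra mma with a tile A1 filled with 0x01010101; for the rest they come
    // pre-computed from the y block sums read out of y_ds.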
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; + const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 + : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y + : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); + + tile_C Cm; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm, A1, B[0]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd; + mma(Cd, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); + float tmp = Cd.x[l]*dm.x; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm.x[l]*dm.y; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; + } + } + } + } +#elif defined(TURING_MMA_AVAILABLE) + + typedef tile<16, 4, int> tile_A; + typedef tile<16, 8, int> tile_A_8; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
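    // [Descriptive annotation, not part of the original patch.] Same Q2_K decomposition as
    // the MFMA branch above, but with the per-group (d*sc, m*sc) factors hoisted into the
    // dA/mA register arrays first. The main k01 loop accumulates the scaled dot products
    // (dB.x covers the first half of the tile, dB.y the second), folding in the Cm
    // correction for the last quarter, and a separate trailing loop subtracts the
    // mA * block-sum terms for the first three quarters using the sums stored in y_ds.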
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); + + tile_A A[ntx][8]; + float dA[ntx][tile_C::ne/2][8]; + float mA[ntx][tile_C::ne/2][8]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) { + const int k0 = k00 + k01; + + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]); + + dA[n][l][k01/(QI8_1/2)] = dm.x; + mA[n][l][k01/(QI8_1/2)] = dm.y; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + float2 dB[tile_C::ne/2]; + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + tile_B B[2]; + + // Here load_generic is faster than load_ldmatrix. + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); + + tile_C Cm[2]; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm[0], A1, B[0]); + mma(Cm[1], A1, B[1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd[2]; + + mma(Cd[0], A[n][k01/4 + 0], B[0]); + mma(Cd[1], A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? 
dB[l%2].x : dB[l%2].y); + } + } + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) { + float2 sB[tile_C::ne/2]; + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y; + } + } + } + } +#else + GGML_UNUSED_VARS(x, y, sum, k00); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE +} + +template static __device__ __forceinline__ void load_tiles_q3_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0; + + const int x_ql_0 = get_int_b2(bxi->qs, kqsx); + const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2))); + +#pragma unroll + for (int l = 0; l < QR3_K; ++l) { + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; + + const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303; + const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404; + + const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + } + + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0; + + const int ksc = threadIdx.x % 4; + + const int ksc_low = ksc % (QI3_K/8); + const int shift_low = 4 * (ksc / (QI3_K/8)); + const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; + + const int ksc_high = QI3_K/8; + const int shift_high = 2 * ksc; + const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; + + const int sc = __vsubss4(sc_low | sc_high, 0x20202020); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + const int8_t * sc8 = (const int8_t *) ≻ + const float d = bxi->d; + +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l]; + } +#else + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; +#endif 
// defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + +#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0; + + x_df[i] = bxi->d; + } +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +} + +template +static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, + x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) { + // scale arrangement after the following two lines: + // - ksc == 0: sc0, sc1, sc2, sc3 + // - ksc == 1: sc4, sc5, sc6, sc7 + // - ksc == 2: m0, m1, m2, m3 + // - ksc == 3: m4, m5, m6, m7 + return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits + ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits +} + +template static __device__ __forceinline__ void load_tiles_q4_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0; + const int qs0 = get_int_b4(bxi->qs, txi); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; +#else + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { +#else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); + + #pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } + } +#else +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0; + + x_dm[i] = bxi->dm; + } + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0 + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); + + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; + } +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +} + +template +static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += 
nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16); + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq( + &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template static __device__ __forceinline__ void load_tiles_q5_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; + const int ky = QR5_K*txi; + + const int ql = get_int_b4(bxi->qs, txi); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010; + + const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { +#else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 
dm = bxi->dm * make_half2(1.0f, -1.0f); + +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } + } +#else +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; + + x_dm[i] = bxi->dm; + } + + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; + + const int * scales = (const int *) bxi->scales; + + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); + + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; + } +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +} + +template +static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16); + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq( + &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template static __device__ __forceinline__ void load_tiles_q6_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); + int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0; + + const int ql = get_int_b2(bxi->ql, txi); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4)); + const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030; + const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030; + + const int kq0 = 2*txi - txi % (QI6_K/2) + 0; + const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; +#else + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0 + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); +#else + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]); + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq( + &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], 
&y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; + } + } + } + } +#elif defined(TURING_MMA_AVAILABLE) + + typedef tile<16, 4, int> tile_A; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
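+ // Warps are grouped in sets of ntx: threadIdx.y / ntx selects the block of x rows (i0 below), while threadIdx.y % ntx offsets y so each warp in a group reads a different column tile.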
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); + + tile_A A[ntx][8]; + int scA[ntx][tile_C::ne/2][8]; + float dA[ntx][tile_C::ne/2]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) { + const int k0 = k00 + k01; + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + + const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16]; + const int8_t * sc = (const int8_t *) &sc_packed; + +#pragma unroll + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][k01/4 + ksc] = sc[ksc]; + } + } + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K]; + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + float tmp[ntx][tile_C::ne] = {{0.0f}}; + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { + tile_B B[2]; + float dB[tile_C::ne/2]; + + // Here load_generic is faster than load_ldmatrix. + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C[2]; + mma(C[0], A[n][k01/4 + 0], B[0]); + mma(C[1], A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2]; + } + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2]; + } + } + } +#else + GGML_UNUSED_VARS(x, y, sum, k00); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE +} + +template static __device__ __forceinline__ void load_tiles_iq4_nl( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_NL; + const int kqsx = txi % QI4_NL; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_nl * bxi = (const block_iq4_nl *)(x + i*stride) + kbx0 + kbx; + + const int aux_q4 = get_int_b2(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); + const int k0 = kbx * (2 * QI4_NL) + kqsx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_nl * bxi = (const block_iq4_nl *)(x + i*stride) + kbx0 + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); +#else + x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xxs * bxi = (const block_iq2_xxs *)(x + i*stride) + kbx0; + + const int q2 = get_int_b2(bxi->qs, 2*kqsx+0); + const uint8_t * aux8 = (const uint8_t *) &q2; + const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1); + +#pragma unroll + for (int l = 0; l < QR2_XXS; ++l) { + const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); + const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; + + const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); + const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + + const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); + const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_xs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xs * bxi = (const block_iq2_xs *)(x + i*stride) + kbx0; + + const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint16_t * q2 = (const uint16_t *) &q2_packed; + + #pragma unroll + for (int l = 0; l < QR2_XS; ++l) { + const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + + const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; 
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_s * bxi = (const block_iq2_s *)(x + i*stride) + kbx0; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR2_S; ++l) { + const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300))); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void 
load_tiles_iq3_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_xxs * bxi = (const block_iq3_xxs *)(x + i*stride) + kbx0; + + const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint8_t * q3 = (const uint8_t *) &q3_packed; + const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx); + +#pragma unroll + for (int l = 0; l < QR3_XXS; ++l) { + const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); + + const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); + + const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq3_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_s * bxi = (const block_iq3_s *)(x + i*stride) + kbx0; + + const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), 
get_int_b2(bxi->qs, 2*kqsx+1)); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->signs, kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR3_S; ++l) { + const int2 grid_pos = make_int2( + iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)], + iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq1_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq1_s * bxi = (const block_iq1_s *)(x + i*stride) + kbx0; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + #pragma unroll + for (int l = 0; l < QR1_S/2; ++l) { + const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); + const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * 
(2.0f*IQ1S_DELTA/0x8000); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); +#else + x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_iq1_s_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kbx = threadIdx.x / 4; + const int kqsx = threadIdx.x % 4; + + int32_t grid32[2]; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + const int i4 = i/4; + const int ir = i%4; + + const block_iq1_s_r4 * bxi = (const block_iq1_s_r4 *)(x + 4*i4*stride + 4*sizeof(half)) + kbx0 + kbx; + + grid32[0] = iq1s_grid_gpu[bxi->qs[4*kqsx+ir] | (((bxi->qh[ir] >> 3*kqsx) & 7) << 8)]; + grid32[1] = ((grid32[0] >> 4) & 0x0f0f0f0f) << 3; + grid32[0] = (grid32[0] & 0x0f0f0f0f) << 3; + const int shift = bxi->qh[ir] & 0x8000 ? 0x09090909 : 0x07070707; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kbx + 2*kqsx + 0] = __vsubss4(grid32[0], shift); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kbx + 2*kqsx + 1] = __vsubss4(grid32[1], shift); +#else + // TODO + //x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; +#endif // INT8_MMA_AVAILABLE + } + + const int blocks_per_tile_x_row = WARP_SIZE / 4; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + const int i4 = i/4; + const int ir = i%4; + + const half * dptr = (const half *)(x + 4*i4*stride); + const block_iq1_s_r4 * bxi = (const block_iq1_s_r4 *)(dptr + 4) + kbx0 + kbxd; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = 0.125f * __half2float(dptr[ir]) * (((bxi->qh[ir] >> 11) & 14) + 1); +#else + // TODO + //x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq4_xs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0; + + const int aux_q4 = get_int_b4(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); + const int k0 = 8 * (kqsx / 4) + kqsx % 4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int rows_per_warp = warp_size / 8; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0; + + const float d = __half2float(bxi->d); + + const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) + | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template +static __device__ __forceinline__ void mmq_write_back_dp4a_id( + const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst, + const int stride, const int i_max, const int j_max) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } + + dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; + } + } +} + +template +static __device__ __forceinline__ void mmq_write_back_mma_id( + const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst, + const int stride, const int i_max, const int j_max) { + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int nwarps = mmq_get_nwarps_device(); + +#if defined(AMD_MFMA_AVAILABLE) + constexpr int tileC_IJ = mmq_get_granularity_device(0); + typedef tile tile_C; + constexpr int rows_per_warp = granularity; +#else + typedef tile<16, 8, int> tile_C; + constexpr int rows_per_warp = 2 * granularity; +#endif // defined(AMD_MFMA_AVAILABLE) + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
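+ // i0 below is the first dst row handled by this warp group's ntx minitiles; each column j of the tile is scattered to row ids_dst[j] of dst.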
+ + const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); +#else + GGML_UNUSED(nwarps); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l); + + if (j > j_max) { + continue; + } + + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + + if (need_check && i > i_max) { + continue; + } + + dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l]; + } + } + } +} + +// ===================================== ik_llama.cpp quants ============================================================= + +// +// Strictly speaking, we should bite the bullet and change WARP_SIZE to warp_size or MMQ_TILE_NE_K. +// But as we basically don't support anything but Nvidia in the CUDA backend, we always have +// WARP_SIZE = MMQ_TILE_NE_K = 32 +// + + +// ------------------------------------------------------------------------------------------------------------------------------------- + +template +struct mmq_type_traits_id; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles =
load_tiles_q2_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; + static constexpr 
load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + //static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +static __device__ __forceinline__ void mul_mat_q_process_tile_id( + const char * __restrict__ x, const int * __restrict__ y, + const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int stride_row_x, const int ncols_y, const int stride_col_dst, + const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) { + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int mmq_y = get_mmq_y_device(); + constexpr load_tiles_mmq_t load_tiles = mmq_type_traits_id::load_tiles; + + extern __shared__ int data_mul_mat_q[]; + int * tile_y = data_mul_mat_q + mmq_x; + int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits_id::vec_dot_mma; + constexpr mmq_write_back_t write_back = mmq_write_back_mma_id; +#else + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits_id::vec_dot_dp4a; + constexpr mmq_write_back_t write_back = mmq_write_back_dp4a_id; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int blocks_per_iter = MMQ_ITER_K / qk; + + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + + for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { + load_tiles(x, tile_x, kb0, tile_x_max_i, stride_row_x); + + { + const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); +#pragma unroll + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; + + tile_y[l] = by0[l]; + } + } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, 0); + + __syncthreads(); + + { + const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); +#pragma unroll + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; + + tile_y[l] = by0[l]; + } + } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K); + + __syncthreads(); + } + + if (fixup) { + write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x); + } else { + write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j); + } +} + + +// The mul_mat_q_id kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 + +template +#if defined(GGML_USE_HIP) +#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) +#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) +#else +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1) +#else + 
__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(GGML_USE_HIP) +static __global__ void mul_mat_q_id( + const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, + const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ncols_max) { + + // Skip unused template specializations for faster compilation: + if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { + NO_DEVICE_CODE; + return; + } + + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int mmq_y = get_mmq_y_device(); + + const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x + const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y + + // Initialize the ids for writing back data with just the index. + // For regular matrix multiplications this is never changed. + // For MoE the correct indices are loaded from ids_dst. + extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory. +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = j; + } + __syncthreads(); + + // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: +#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA + { + const int wt = blockIdx.z / nchannels_y; + const int zt = blockIdx.z - wt*nchannels_y; + const int jt = blockIdx.y; + const int it = blockIdx.x; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + return; + } + + // __syncthreads(); // There is no previous tile that could cause a race condition. 
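+ // Load this expert's destination row indices into shared memory so the write-back can scatter its output rows directly.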
+#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int64_t offset_x = (wt/sample_ratio )*int64_t(stride_sample_x) + + (zt/channel_ratio)*int64_t(stride_channel_x) + it*mmq_y*int64_t(stride_row_x); + + constexpr bool fixup = false; + mul_mat_q_process_tile_id + (x + offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); + return; + } +#endif // (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA + + const int64_t blocks_per_ne00 = ncols_x / qk; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; + + // kbc == k block continuous, current index in continuous ijk space. + int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; + + // kb0 == k index when doing the matrix multiplication for an output tile. + int kb0_start = kbc % blocks_per_ne00; + int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) { + int tmp = kbc; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + kbc += blocks_per_ne00; + kbc -= kbc % blocks_per_ne00; + + kb0_start = 0; + kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + + continue; + } + + __syncthreads(); +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int64_t offset_x = (wt/sample_ratio )*int64_t(stride_sample_x) + + (zt/channel_ratio)*int64_t(stride_channel_x) + it*mmq_y*int64_t(stride_row_x); + + constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
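+ // This block finishes the K range of the current tile (kb0_stop == blocks_per_ne00), so its result is written straight to dst; partial sums from a block that only started the tile are merged later by the fixup kernel.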
+ mul_mat_q_process_tile_id + (x + offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); + + kbc += blocks_per_ne00; + kbc -= kbc % blocks_per_ne00; + + kb0_start = 0; + kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + int tmp = kbc; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + return; + } + + // The memory layout for the fixup buffer is always contiguous, therefore reset ids: + __syncthreads(); +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = j; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int64_t offset_x = (wt/sample_ratio )*int64_t(stride_sample_x) + + (zt/channel_ratio)*int64_t(stride_channel_x) + it*mmq_y*int64_t(stride_row_x); + + constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + mul_mat_q_process_tile_id + (x + offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); +} + + +template +static __global__ void mul_mat_q_stream_k_fixup_id( + const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, + const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, + const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst, + const int ncols_max) { + constexpr int mmq_y = get_mmq_y_device(); + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; + const int64_t blocks_per_ne00 = ncols_x / qk; + + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + + const int ntx = (ncols_max + mmq_x - 1) / mmq_x; + const int nty = (nrows_x + mmq_y - 1) / mmq_y; + + const int bidx0 = blockIdx.x; + + // kbc == k block continuous, current index in continuous ijk space. 
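+ // Recompute the same [kbc0, kbc0_stop) range that mul_mat_q_id assigned to this block; partial sums that earlier blocks left in the fixup buffer for this block's first tile are accumulated below and added to dst.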
+ int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + + kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter; + kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0; + const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + bool any_fixup = false; + + // Iterate over previous blocks and sum up partial sums written to fixup buffer. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int64_t bidx = bidx0 - 1; + int64_t kbc_stop = kbc0; + while(true) { + int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + any_fixup = true; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; + } + } + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) { + break; + } + bidx--; + kbc_stop = kbc; + } + + if (!any_fixup) { + return; + } + + int tmp = kbc0; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + if (!ids_dst) { + const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y; + dst += offset_dst; + + const int i_max = nrows_x - it*mmq_y - 1; + const int j_max = ncols_dst - jt*mmq_x - 1; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } + + dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; + } + } + return; + } + + __shared__ int ids_dst_shared[mmq_x]; + const int col_low = expert_bounds[zt + 0]; + const int col_high = expert_bounds[zt + 1]; + const int col_diff = col_high - col_low; + + for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) { + ids_dst_shared[j] = ids_dst[col_low + j]; + } + __syncthreads(); + + const int offset_dst = it*mmq_y; + dst += offset_dst; + + const int i_max = nrows_x - it*mmq_y - 1; + const int j_max = col_diff - jt*mmq_x - 1; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } 
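+            // Add the gathered partial sums into dst at the column given by ids_dst_shared[j],
+            // i.e. in the destination layout described by ids_dst rather than the contiguous,
+            // expert-sorted order used for the fixup buffer itself.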
+ + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; + } + } +} + +struct mmq_args_id { + const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst; + int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst; + int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst; + int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst; + bool use_stream_k; int64_t ncols_max; +}; + +template +static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) { + const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); + const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); + const size_t nbs_ids = mmq_x*sizeof(int); + const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); + return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); +} + +template +static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; + const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc, warp_size); + const int mmq_y = get_mmq_y_host(cc); + + const dim3 block_dims(warp_size, nwarps, 1); + + const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps); + + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q_id), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q_id), nbytes_shared); + + const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; + const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x; + const int ntzw = args.nchannels_y * args.nsamples_y; + const dim3 block_nums_xy_tiling(nty, ntx, ntzw); + + GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0); + GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0); + const int channel_ratio = args.nchannels_y / args.nchannels_x; + const int sample_ratio = args.nsamples_y / args.nsamples_x; + + if (!args.use_stream_k) { + if (args.nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + mul_mat_q_id<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + } else { + constexpr bool need_check = true; + mul_mat_q_id<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + } + return; + } + + const dim3 block_nums_stream_k(nsm, 1, 1); + const bool fixup_needed = ntx*nty*ntzw % nsm != 
0; + + ggml_cuda_pool & pool = ctx.pool(id); + ggml_cuda_pool_alloc tmp_fixup(pool); + if (fixup_needed) { + tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y); + } + + if (args.nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + mul_mat_q_id<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + + if (!fixup_needed) { + return; + } + + mul_mat_q_stream_k_fixup_id<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, + args.ncols_max); + } else { + constexpr bool need_check = true; + mul_mat_q_id<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + + if (!fixup_needed) { + return; + } + + mul_mat_q_stream_k_fixup_id<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, + args.ncols_max); + } +} + +template +void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc, warp_size); + + const int mmq_x_max = get_mmq_x_max_host(cc); + const int mmq_y = get_mmq_y_host(cc); + + int mmq_x_best = 0; + int ntiles_x_best = INT_MAX; + + for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) { + const int granularity = mmq_get_granularity_host(mmq_x, cc); + + if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) { + continue; + } + + const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x; + + if (ntiles_x < ntiles_x_best) { + mmq_x_best = mmq_x; + ntiles_x_best = ntiles_x; + } + } + + switch (mmq_x_best) { + case 8: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 16: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 24: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 32: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 40: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 48: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 56: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 64: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 72: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 80: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 88: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 96: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 104: + launch_mul_mat_q_id(ctx, args, 
stream); + break; + case 112: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 120: + launch_mul_mat_q_id(ctx, args, stream); + break; + case 128: + launch_mul_mat_q_id(ctx, args, stream); + break; + default: + fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best); + GGML_ABORT("fatal error"); + break; + } +} + +#define DECL_MMQ_CASE(type) \ + template void mul_mat_q_case_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) \ + +extern DECL_MMQ_CASE(GGML_TYPE_Q4_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q6_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); +extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); +extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q6_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KL); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_K_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_K_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KSS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_K_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4); +extern DECL_MMQ_CASE(GGML_TYPE_IQ6_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_KT); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KT); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KT); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); + +// ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/quantize_id.cu b/ggml/src/ggml-cuda/quantize_id.cu new file mode 100644 index 000000000..80a9f2220 --- /dev/null +++ b/ggml/src/ggml-cuda/quantize_id.cu @@ -0,0 +1,132 @@ +#include "quantize_id.cuh" +#include "mmq.cuh" +#include + +template +static __global__ void quantize_mmq_q8_1( + const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int ne1, const int ne2) { + + constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; + constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; + + const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4; + + if (i0 >= ne0) { + return; + } + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; + + const int64_t i00 = i0; + const int64_t i01 = ids ? 
ids[i1] : i1; + const int64_t i02 = i2; + const int64_t i03 = i3; + + const float4 * x4 = (const float4 *) x; + + block_q8_1_mmq * y = (block_q8_1_mmq *) vy; + + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel + const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel + const int64_t iqs = i0 % (4*QK8_1); // quant index in block + + // Load 4 floats per thread and calculate max. abs. value between them: + const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float amax = fabsf(xi.x); + amax = fmaxf(amax, fabsf(xi.y)); + amax = fmaxf(amax, fabsf(xi.z)); + amax = fmaxf(amax, fabsf(xi.w)); + + // Exchange max. abs. value between vals_per_scale/4 threads. +#pragma unroll + for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE)); + } + + float sum; + if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) { + sum = xi.x + xi.y + xi.z + xi.w; + + // Calculate sums across vals_per_sum/4 threads. +#pragma unroll + for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) { + sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE); + } + } + + const float d_inv = 127.0f / amax; + char4 q; + q.x = roundf(xi.x*d_inv); + q.y = roundf(xi.y*d_inv); + q.z = roundf(xi.z*d_inv); + q.w = roundf(xi.w*d_inv); + + // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth: + char4 * yqs4 = (char4 *) y[ib].qs; + yqs4[iqs/4] = q; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) { + if (iqs % 16 != 0 || iqs >= 96) { + return; + } + + y[ib].d2s6[2 + iqs/16] = sum; + + if (iqs % 64 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + y[ib].d2s6[iqs/64] = d; + + return; + } + + if (iqs % 32 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) { + y[ib].ds4[iqs/32] = make_half2(d, sum); + } else { + y[ib].d4[iqs/32] = d; + } +} + +void quantize_mmq_q8_1_cuda_id( + const float * x, const int32_t * ids, void * vy, const ggml_type type_src0, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne0 % (4*QK8_1) == 0); + GGML_ASSERT(ids); + + // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid: + const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); + const dim3 num_blocks(ne1, block_num_y, ne2*ne3); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); + switch (mmq_get_q8_1_ds_layout(type_src0)) { + case MMQ_Q8_1_DS_LAYOUT_D4: + quantize_mmq_q8_1 + <<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + break; + case MMQ_Q8_1_DS_LAYOUT_DS4: + quantize_mmq_q8_1 + <<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + break; + case MMQ_Q8_1_DS_LAYOUT_D2S6: + quantize_mmq_q8_1 + <<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} diff --git a/ggml/src/ggml-cuda/quantize_id.cuh b/ggml/src/ggml-cuda/quantize_id.cuh new file mode 100644 index 000000000..d37539236 --- /dev/null +++ b/ggml/src/ggml-cuda/quantize_id.cuh @@ -0,0 +1,16 @@ +#pragma once + +#include "common.cuh" + +#include + +#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128 + +//static_assert(MATRIX_ROW_PADDING 
% CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access."); +//static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access."); + +void quantize_mmq_q8_1_cuda_id( + const float * x, const int32_t * ids, void * vy, + ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt_id.cu new file mode 100644 index 000000000..bbdc750cf --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_kt_id.cu @@ -0,0 +1,83 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq1_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq1_kt * bxi = (const block_iq1_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + uint32_t val = bxi->ql[kqsx] + ((bxi->qh[kqsx%16] << (8 - 4*(kqsx/16))) & 0xf00) + ((bxi->sh[kqsx/4] << (8 - (kqsx%4))) & 0x1000) + 4096; + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val *= ka; + v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + for (int k = 0; k < 4; ++k) { + val *= ka; + v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0]; + const block_iq1_kt * bxi = (const block_iq1_kt *)(dptr + 1) + kbx0; + const int ls = iq4k_values[bxi->sh[threadIdx.x % 8] & 0xf]; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ1_KT); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_id.cu new file mode 100644 index 000000000..29d2e2ce2 --- /dev/null +++ 
b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s_id.cu @@ -0,0 +1,6 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ1_S); +DECL_MMQ_CASE(GGML_TYPE_IQ1_S_R4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_id.cu new file mode 100644 index 000000000..3c3fd4c9d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_k_id.cu @@ -0,0 +1,200 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq2_k( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + //constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + + #pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_k * bxi = (const block_iq2_k *)(x + i*stride) + kbx0; + + const float d = bxi->d; + uint16_t extra = bxi->extra >> (kqsx/4); + +#ifdef __CUDA_ARCH__ + + uint32_t extra32[2] = { uint32_t(extra & 0xff) * 0x01010101, uint32_t(extra >> 8) * 0x01010101 }; + #pragma unroll + for (int l = 0; l < qstep/4; ++l) { + const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); + uint32_t val1 = ((ql >> 0) & 0x33333333) | ((extra32[l] << 2) & 0x44444444); + uint32_t val2 = ((ql >> 2) & 0x33333333) | ((extra32[l] << 0) & 0x44444444); + int2 v1 = get_int_from_table_8(val1, iq2nl_values); + int2 v2 = get_int_from_table_8(val2, iq2nl_values); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y; +#endif // INT8_MMA_AVAILABLE + } + +#else + + auto all_values = (const int *)iq2k_table; + + #pragma unroll + for (int l = 0; l < qstep/4; ++l) { + + const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, all_values + ((extra & 0x01) << 8)); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = int_from_table_4((ql >> 2) & 0x03030303, all_values + ((extra & 0x04) << 6)); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = int_from_table_4((ql >> 4) & 0x03030303, all_values + ((extra & 0x10) << 4)); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = int_from_table_4((ql >> 6) & 0x03030303, all_values + ((extra & 0x40) << 2)); +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, all_values + ((extra & 0x01) << 8)); + 
x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = int_from_table_4((ql >> 2) & 0x03030303, all_values + ((extra & 0x04) << 6)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = int_from_table_4((ql >> 4) & 0x03030303, all_values + ((extra & 0x10) << 4)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = int_from_table_4((ql >> 6) & 0x03030303, all_values + ((extra & 0x40) << 2)); +#endif // INT8_MMA_AVAILABLE + + extra >>= 8; + } +#endif // __CUDA_ARCH__ + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * (((bxi->scales[kqsx] >> 0) & 0xf) - 8); + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * (((bxi->scales[kqsx] >> 4) & 0xf) - 8); +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * (((bxi->scales[kqsx] >> 0) & 0xf) - 8); + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * (((bxi->scales[kqsx] >> 4) & 0xf) - 8); +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_k_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; // 0...7 -> block of 32 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const block_iq2_k_r4 * bxi = (const block_iq2_k_r4 *)(x + 4*i4*stride) + kbx0; + + const float d = __half2float(bxi->d[ir]); + +#ifdef __CUDA_ARCH__ + #pragma unroll + for (int l = 0; l < 2; ++l) { + + uint32_t extra = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x04040404; + extra = extra | (extra << 4); + + const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l); + uint32_t val1 = ((ql >> 0) & 0x33333333) | extra; + uint32_t val2 = ((ql >> 2) & 0x33333333) | extra; + int2 v1 = get_int_from_table_8(val1, iq2nl_values); + int2 v2 = get_int_from_table_8(val2, iq2nl_values); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y; +#endif // INT8_MMA_AVAILABLE + } + +#else + #pragma unroll + for (int l = 0; l < 2; ++l) { + + auto values_l = (const int *)iq2k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 8); + + const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = int_from_table_4((ql >> 2) & 0x03030303, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = int_from_table_4((ql >> 4) & 0x03030303, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l); +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = int_from_table_4((ql >> 0) & 0x03030303, values_l); + 
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = int_from_table_4((ql >> 2) & 0x03030303, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = int_from_table_4((ql >> 4) & 0x03030303, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_4((ql >> 6) & 0x03030303, values_l); +#endif // INT8_MMA_AVAILABLE + } +#endif // __CUDA_ARCH__ + + int is = 8*kqsx + ir; + float dl1 = d * (((bxi->scales[is%32] >> 4*(is/32)) & 0xf) - 8); + is += 4; + float dl2 = d * (((bxi->scales[is%32] >> 4*(is/32)) & 0xf) - 8); + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = dl1; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = dl2; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = dl1; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = dl2; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_k; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_k_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ2_K); +DECL_MMQ_CASE(GGML_TYPE_IQ2_K_R4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kl_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kl_id.cu new file mode 100644 index 000000000..66e51797e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kl_id.cu @@ -0,0 +1,72 @@ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq2_kl( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; + + uint32_t aux32[2]; + const uint8_t * a8 = (const uint8_t *)aux32; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + + const half * dptr = (const half *)(x + i*stride); + const float d = *dptr; + const block_iq2_kl * bxi = (const block_iq2_kl *)(dptr + 1) + kbx0; + + #pragma unroll + for (int j = 0; j < 2; ++j) { + auto ql = get_int_b2(bxi->qs, 4*(kqsx/2) + 2*(kqsx%2) + j); + auto qh = get_int_b2(bxi->qh, 2*(kqsx%2) + j) >> 2*(kqsx/2); + aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010); + aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010); + #pragma unroll + for (int l = 0; l < 2; ++l) { + int val1 = iq2kl_values[a8[2*l+0]] | (iq2kl_values[a8[2*l+1]] << 16); + int val2 = iq2kl_values[a8[2*l+4]] | (iq2kl_values[a8[2*l+5]] << 16); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 0] = val1; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 8] = val2; +#else + x_qs[i*(2*WARP_SIZE + 1) + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 0] = val1; + x_qs[i*(2*WARP_SIZE + 1) + 16*(kqsx/2) + 4*(kqsx%2) + 2*j + l + 8] = val2; 
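+                // Each byte of aux32 holds a 5-bit index (4 bits from qs, 1 bit from qh); iq2kl_values
+                // maps it to a pair of int8 weights packed into 16 bits, so the two lookups above
+                // assemble four dequantized weights per 32-bit x_qs entry. The MMA and dp4a branches
+                // store the same values and differ only in the shared-memory row stride.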
+#endif + } + } + + int ls = int(((bxi->scales_l[kqsx%4] >> 4*(kqsx/4)) & 0xf) | (((bxi->scales_h >> 2*kqsx) & 3) << 4)) - 32; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = d * ls; +#endif + } + +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_kl; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ2_KL); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_ks_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_ks_id.cu new file mode 100644 index 000000000..87759b593 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_ks_id.cu @@ -0,0 +1,114 @@ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq2_ks( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x%16; + +#ifdef __CUDA_ARCH__ + #pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) { + int i = i0 + 2*threadIdx.y + threadIdx.x/16; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_ks * bxi = (const block_iq2_ks *)(x + i*stride + sizeof(half)) + kbx0; + + uint16_t extra = bxi->extra >> 4*(kqsx/8); + int q2 = get_int_b2(bxi->qs, kqsx); + + uint32_t extra32 = uint32_t(extra & 0xf) * 0x01010101; + uint32_t val1 = ((q2 >> 0) & 0x33333333) | ((extra32 << 2) & 0x04040404) | ((extra32 << 4) & 0x40404040); + uint32_t val2 = ((q2 >> 2) & 0x33333333) | ((extra32 << 1) & 0x04040404) | ((extra32 << 3) & 0x40404040); + int2 v1 = get_int_from_table_8(val1, iq2nl_values); + int2 v2 = get_int_from_table_8(val2, iq2nl_values); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 16] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = v2.y; +#endif // INT8_MMA_AVAILABLE + } + +#else // __CUDA_ARCH__ + + + const int * all_values = (const int *)iq2k_table; + #pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 2*nwarps) { + int i = i0 + 2*threadIdx.y + threadIdx.x/16; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_ks * bxi = (const block_iq2_ks *)(x + i*stride + sizeof(half)) + kbx0; + + uint16_t extra = bxi->extra >> 4*(kqsx/8); + int q2 = get_int_b2(bxi->qs, kqsx); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 0] = int_from_table_4((q2 >> 0) & 0x03030303, all_values + ((extra & 1) << 8)); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 8] = int_from_table_4((q2 >> 2) & 0x03030303, all_values + ((extra & 2) << 7)); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 
+ kqsx%8 + 32*(kqsx/8) + 16] = int_from_table_4((q2 >> 4) & 0x03030303, all_values + ((extra & 4) << 6)); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx%8 + 32*(kqsx/8) + 24] = int_from_table_4((q2 >> 6) & 0x03030303, all_values + ((extra & 8) << 5)); +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 0] = int_from_table_4((q2 >> 0) & 0x03030303, all_values + ((extra & 1) << 8)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 8] = int_from_table_4((q2 >> 2) & 0x03030303, all_values + ((extra & 2) << 7)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 16] = int_from_table_4((q2 >> 4) & 0x03030303, all_values + ((extra & 4) << 6)); + x_qs[i*(2*WARP_SIZE + 1) + kqsx%8 + 32*(kqsx/8) + 24] = int_from_table_4((q2 >> 6) & 0x03030303, all_values + ((extra & 8) << 5)); +#endif // INT8_MMA_AVAILABLE + } +#endif // __CUDA_ARCH__ + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = i0 + threadIdx.y * 8 + threadIdx.x / 4; + + if (need_check) { + i = min(i, i_max); + } + + const half * dptr = (const half *)(x + i*stride); + const float d = dptr[0]; + const block_iq2_ks * bxi = (const block_iq2_ks *)(dptr + 1) + kbx0; + const int ls1 = ((bxi->scales[threadIdx.x % 4] >> 0) & 0xf) | ((bxi->extra >> (4 + 2*(threadIdx.x % 4))) & 0x10); + const int ls2 = ((bxi->scales[threadIdx.x % 4] >> 4) & 0xf) | ((bxi->extra >> (5 + 2*(threadIdx.x % 4))) & 0x10); + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + 2*(threadIdx.x % 4) + 0] = d * (ls1 - 16); + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + 2*(threadIdx.x % 4) + 1] = d * (ls2 - 16); +#else + x_df[i*(WARP_SIZE/4) + i/4 + 2*(threadIdx.x % 4) + 0] = d * (ls1 - 16); + x_df[i*(WARP_SIZE/4) + i/4 + 2*(threadIdx.x % 4) + 1] = d * (ls2 - 16); +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_ks; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ2_KS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt_id.cu new file mode 100644 index 000000000..5c2d77e77 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_kt_id.cu @@ -0,0 +1,85 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
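+// IQ2_KT is a trellis-coded quant: load_tiles_iq2_kt below regenerates the weights on the fly by
+// seeding a multiplicative recurrence (val *= 0xCBAC1FED) with each packed 16-bit value plus 4096;
+// every step emits one int8 weight as the byte-sum of (val & 0x3f3f3f3f) minus 126 (computed with
+// ggml_cuda_dp4a), and the per-block scale is taken from iq4k_values via the 4-bit scales field.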
+ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq2_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_kt * bxi = (const block_iq2_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto ql = (const uint16_t *)bxi->ql; + uint32_t val = ql[4*ib32+j] + 4096; + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val *= ka; + v.x |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } + for (int k = 0; k < 4; ++k) { + val *= ka; + v.y |= (ggml_cuda_dp4a(val & km, 0x01010101, -126) & 0xff) << 8*k; + } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0] * 1.05f; + const block_iq2_kt * bxi = (const block_iq2_kt *)(dptr + 1) + kbx0; + int ib32 = threadIdx.x % 8; + const int ls = iq4k_values[(bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf]; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ2_KT); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s_id.cu new file mode 100644 index 000000000..6f54f24e6 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs_id.cu new file mode 100644 index 000000000..529586678 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
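+// Types that can reuse the generic tile loaders and dot products from mmq_id_common.cuh only need
+// the DECL_MMQ_CASE() instantiation below; types with custom unpacking (the *_k, *_ks, *_kl and
+// *_kt instances in this directory) additionally define their own load_tiles_* and specialize
+// mmq_type_traits_id before instantiating.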
+ +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs_id.cu new file mode 100644 index 000000000..c735ab977 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_id.cu new file mode 100644 index 000000000..7e7ed12a2 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_k_id.cu @@ -0,0 +1,164 @@ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq3_k( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_k * bxi = (const block_iq3_k *)(x + i*stride) + kbx0; + + const float d = bxi->d; + + uint16_t extra = bxi->extra >> (kqsx/4); + uint32_t extra32[2] = { uint32_t(extra & 0xff) * 0x01010101, uint32_t(extra >> 8) * 0x01010101 }; + int qh = get_int_b2(bxi->qh, kqsx); + + #pragma unroll + for (int l = 0; l < qstep/4; ++l) { + + //extra << 3, extra << 1, extra >> 1, extra >> 3 + const int ql = get_int_b2(bxi->qs, kqsx + qstep*l); + uint32_t val1 = ((ql >> 0) & 0x33333333) | ((extra32[l] << 3) & 0x88888888) + | ((qh << 2) & 0x04040404) | ((qh << 4) & 0x40404040); + uint32_t val2 = ((ql >> 2) & 0x33333333) | ((extra32[l] << 1) & 0x88888888) + | ((qh << 1) & 0x04040404) | ((qh << 3) & 0x40404040); + int2 v1 = get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); + + qh >>= 4; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 8] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 16] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 32*l + 24] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y; +#endif // INT8_MMA_AVAILABLE + } + + uint16_t sh = bxi->scales_h >> 2*kqsx; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ((2*(bxi->scales_l[kqsx] & 0xf) + 1) * (sh & 1 ? -1 : 1)); + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ((2*(bxi->scales_l[kqsx] >> 4) + 1) * (sh & 2 ? -1 : 1)); +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ((2*(bxi->scales_l[kqsx] & 0xf) + 1) * (sh & 1 ? -1 : 1)); + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ((2*(bxi->scales_l[kqsx] >> 4) + 1) * (sh & 2 ? 
-1 : 1)); +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq3_k_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; // 0...7 -> block of 32 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const block_iq3_k_r4 * bxi = (const block_iq3_k_r4 *)(x + 4*i4*stride) + kbx0; + + const float d = __half2float(bxi->d[ir]); + + int qh = get_int_b4(bxi->qh, 4*kqsx+ir); + + #pragma unroll + for (int l = 0; l < 2; ++l) { + + //auto values_l = iq3k_table + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 6); + uint32_t extra32 = uint32_t((bxi->extra[ir+4*l] >> kqsx) & 1) * 0x88888888; + + const int ql = get_int_b4(bxi->qs, 8*kqsx + ir + 4*l); + uint32_t val1 = ((ql >> 0) & 0x33333333) | extra32 | ((qh << 2) & 0x04040404) | ((qh << 4) & 0x40404040); + uint32_t val2 = ((ql >> 2) & 0x33333333) | extra32 | ((qh << 1) & 0x04040404) | ((qh << 3) & 0x40404040); + int2 v1 = get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = v2.y; +#endif // INT8_MMA_AVAILABLE + + qh >>= 4; + } + + int is = 8*kqsx + ir; + float dl1 = d * (2*((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) + 1) * ((bxi->scales_h[is%8] >> (is/8)) & 1 ? -1 : 1); + is += 4; + float dl2 = d * (2*((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) + 1) * ((bxi->scales_h[is%8] >> (is/8)) & 1 ? 
-1 : 1); + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = dl1; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = dl2; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = dl1; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = dl2; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_k; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_k_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + + +DECL_MMQ_CASE(GGML_TYPE_IQ3_K); +DECL_MMQ_CASE(GGML_TYPE_IQ3_K_R4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_ks_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_ks_id.cu new file mode 100644 index 000000000..1ecb748f3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_ks_id.cu @@ -0,0 +1,79 @@ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq3_ks( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const half * dptr = (const half *)(x + i*stride); + const float d = __half2float(dptr[0]); + const block_iq3_ks * bxi = (const block_iq3_ks *)(dptr + 1) + kbx0; + + //uint16_t extra = bxi->extra >> 8; + int qh = get_int_b2(bxi->qh, kqsx); + + uint32_t extra32 = uint32_t(bxi->extra >> 8) * 0x01010101; + + #pragma unroll + for (int l = 0; l < qstep/4; ++l) { + + const int ql = get_int_b2(bxi->qs, kqsx + qstep*l); + uint32_t val1 = ((ql >> 0) & 0x33333333) | ((qh << 2) & 0x04040404) | ((extra32 << 3) & 0x08080808) + | ((qh << 4) & 0x40404040) | ((extra32 << 5) & 0x80808080); + uint32_t val2 = ((ql >> 2) & 0x33333333) | ((qh << 1) & 0x04040404) | ((extra32 << 2) & 0x08080808) + | ((qh << 3) & 0x40404040) | ((extra32 << 4) & 0x80808080); + int2 v1 = get_int_from_table_16(val1, iq3nl_values); + int2 v2 = get_int_from_table_16(val2, iq3nl_values); + + extra32 >>= 4; + qh >>= 4; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 0] = v1.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 8] = v2.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 16] = v1.y; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 32*l + 24] = v2.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 0] = v1.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 8] = v2.x; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 16] = v1.y; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 32*l + 24] = v2.y; +#endif // INT8_MMA_AVAILABLE + } + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * (int(((bxi->scales[kqsx%4] >> 4*(kqsx/4)) & 0xf) | (((bxi->extra >> 
kqsx) & 1) << 4)) - 16); +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = d * (int(((bxi->scales[kqsx%4] >> 4*(kqsx/4)) & 0xf) | (((bxi->extra >> kqsx) & 1) << 4)) - 16); +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_ks; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ3_KS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt_id.cu new file mode 100644 index 000000000..6af8b9d64 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_kt_id.cu @@ -0,0 +1,91 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq3_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_kt * bxi = (const block_iq3_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto ql = (const uint16_t *)bxi->ql; + const auto qh = (const uint32_t *)bxi->qh; + uint32_t mask = 0x01010101 << ib32; + uint32_t val = ql[4*ib32+j] + 4096; + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val *= ka; + v.x |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k; + } + auto signs = __vcmpne4(qh[2*j+0] & mask, 0); + v.x = __vsub4(v.x ^ signs, signs); + for (int k = 0; k < 4; ++k) { + val *= ka; + v.y |= std::abs(ggml_cuda_dp4a(val & km, 0x01010101, -126)) << 8*k; + } + signs = __vcmpne4(qh[2*j+1] & mask, 0); + v.y = __vsub4(v.y ^ signs, signs); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0] * 1.01f; + const block_iq3_kt * bxi = (const block_iq3_kt *)(dptr + 1) + kbx0; + int ib32 = threadIdx.x % 8; + const int ls = (bxi->scales[ib32%4] >> 4*(ib32/4)) & 0xf; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * ls; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = 
vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ3_KT); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s_id.cu new file mode 100644 index 000000000..5c501ed36 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs_id.cu new file mode 100644 index 000000000..387e6d274 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_k_id.cu new file mode 100644 index 000000000..f2d267fac --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_k_id.cu @@ -0,0 +1,162 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq4_k( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_k * bxi = (const block_iq4_k *)(x + i*stride) + kbx0; + const uint16_t extra = bxi->extra >> 2*kqsx; + + auto values_l = iq4k_table + ((extra & 1) << 8); + auto values_h = iq4k_table + ((extra & 2) << 7); + + #pragma unroll + for (int l = 0; l < qstep/2; ++l) { + + const int q4 = get_int_b4(bxi->qs, (qstep/2)*kqsx + l); + + aux32[0] = (q4 >> 0) & 0x0f0f0f0f; + aux32[1] = (q4 >> 4) & 0x0f0f0f0f; + + int val0 = int_from_table_x(aux8+0, values_l); + int val1 = int_from_table_x(aux8+4, values_h); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + l + 0] = val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + l + 4] = val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + l + 0] = val0; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + l + 4] = val1; +#endif // INT8_MMA_AVAILABLE + } + + const uint8_t sh = bxi->scales_h[kqsx/2] >> 4*(kqsx%2); + const int ls1 = ((bxi->scales_l[kqsx] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int ls2 = ((bxi->scales_l[kqsx] >> 4) | ((sh << 2) & 0x30)) - 32; + + const float d = bxi->d; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ls1; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ls2; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ls1; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ls2; +#endif // 
INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq4_k_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; // 0...7 -> block of 32 + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const block_iq4_k_r4 * bxi = (const block_iq4_k_r4 *)(x + 4*i4*stride) + kbx0; + + const float d = __half2float(bxi->d[ir]); + + #pragma unroll + for (int l = 0; l < 2; ++l) { + + auto values_l = iq4k_table + ((bxi->extra[ir+4*l] << (8 - kqsx)) & 0x100); + + const int ql1 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 0); + const int ql2 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 8); + aux32[0] = (ql1 >> 0) & 0x0f0f0f0f; + aux32[1] = (ql1 >> 4) & 0x0f0f0f0f; + aux32[2] = (ql2 >> 0) & 0x0f0f0f0f; + aux32[3] = (ql2 >> 4) & 0x0f0f0f0f; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = int_from_table_x(aux8+ 0, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = int_from_table_x(aux8+ 4, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = int_from_table_x(aux8+ 8, values_l); + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = int_from_table_x(aux8+12, values_l); +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = int_from_table_x(aux8+ 0, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = int_from_table_x(aux8+ 4, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = int_from_table_x(aux8+ 8, values_l); + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = int_from_table_x(aux8+12, values_l); +#endif // INT8_MMA_AVAILABLE + + } + + int is = 8*kqsx + ir; + float dl1 = d * ((((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bxi->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + is += 4; + float dl2 = d * ((((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bxi->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = dl1; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = dl2; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = dl1; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = dl2; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_k; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_k_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ4_K); +DECL_MMQ_CASE(GGML_TYPE_IQ4_K_R4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_id.cu new file mode 100644 index 000000000..ae4b4aa08 --- /dev/null 
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks_id.cu @@ -0,0 +1,187 @@ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq4_kss( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x / 4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const block_iq4_kss * bxi = (const block_iq4_kss *)(dptr + 1) + kbx0; + const uint32_t * q4 = bxi->qs + 4*kqsx; + uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6); + uint8_t ls = (s32 | (s32 >> 15)) & 0xff; + + auto values = iq4k_values + ((ls & 1) << 4); + + #pragma unroll + for (int j = 0; j < 4; ++j) { + uint32_t val = q4[j] & 0xfffefffe; + val = val ^ (val >> 1); + auto v = get_int_from_table_16(val, values); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y; +#endif // INT8_MMA_AVAILABLE + } +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ((ls & 254) - 127); +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ((ls & 254) - 127); +#endif // INT8_MMA_AVAILABLE + } + +} + +template static __device__ __forceinline__ void load_tiles_iq4_ks( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x / 4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0; + const int ls = (bxi->scales[kqsx] & 254) - 127; + + auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4); + + #pragma unroll + for (int j = 0; j < 4; ++j) { + const int q4 = get_int_b4(bxi->qs, 4*kqsx+j); + const int2 v = get_int_from_table_16(q4, values); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y; +#endif // INT8_MMA_AVAILABLE + } +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ls; +#endif // INT8_MMA_AVAILABLE + } + +} + +template static __device__ __forceinline__ void 
load_tiles_iq4_ks_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_KS_R4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const float * dptr = (const float *)(x + 4*i4*stride); + const block_iq4_ks_r4 * bxi = (const block_iq4_ks_r4 *)(dptr + 4) + kbx0; + + const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127; + auto values = iq4k_values + ((bxi->scales[4*kqsx+ir] & 1) << 4); + +#pragma unroll + for (int j = 0; j < 4; ++j) { + const int q4 = get_int_b4(bxi->qs, 16*kqsx+4*j+ir); + const int2 v = get_int_from_table_16(q4, values); + const int k0 = 8*kqsx + 4*(j%2) + j/2; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = v.y; +#endif // INT8_MMA_AVAILABLE + } +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[ir] * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[ir] * ls; +#endif // INT8_MMA_AVAILABLE + + } + +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_kss; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ4_KSS); +DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); +DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4); + diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt_id.cu new file mode 100644 index 000000000..a27fe8a8d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kt_id.cu @@ -0,0 +1,86 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
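The IQ4_KSS/IQ4_KS loaders above pack everything the dot product needs into one per-block scale byte: its low bit picks one of the two 16-entry halves of iq4k_values, and the remaining seven bits carry the block scale stored with an offset of 127, multiplied by the row-level float placed in front of the blocks. IQ4_KS stores that byte directly; IQ4_KSS first reassembles it from bits spread over the quant words, but the decode of the byte itself is the same. A minimal host-side sketch of the decode, assuming the usual MMQ convention that a weight is reconstructed as x_df * x_qs and that iq4k_values is the 32-entry signed codebook referenced by the loaders (both are assumptions here, not part of the patch):

// Reference-only sketch; names mirror load_tiles_iq4_ks above, nothing here is part of the PR.
#include <cstdint>

static inline float dequant_iq4_ks_one(float d_row,               // row-level scale (dptr[0] above)
                                        uint8_t scale_byte,        // bxi->scales[kqsx]
                                        uint8_t nibble,            // one 4-bit code from bxi->qs
                                        const int8_t * iq4k_values /* assumed 32-entry table */) {
    const int      ls     = (scale_byte & 254) - 127;                  // block scale, as in the loader
    const int8_t * values = iq4k_values + ((scale_byte & 1) << 4);     // low bit selects the table half
    return d_row * ls * values[nibble & 0xf];                          // x_df * x_qs under the MMQ convention
}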
+ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq4_kt( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + + constexpr uint32_t ka = 0xCBAC1FED; + constexpr uint32_t km = 0x3f3f3f3f; + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_kt * bxi = (const block_iq4_kt *)(x + i*stride + sizeof(float)) + kbx0; + + int ib32 = kqsx/4; + int j = kqsx%4; + const auto shb = bxi->qs; + const auto ql = (const uint8_t *)(shb + 8); + const auto qh = ql + 64; + const uint32_t sh = shb[ib32] >> (8 + 6*j); + uint32_t offset = 4096 + ((shb[ib32] & 1) << 15); + uint32_t val1 = offset + ql[8*ib32+2*j+0] + ((qh[8*(ib32%4)+2*j+0] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 7) << 12); + uint32_t val2 = offset + ql[8*ib32+2*j+1] + ((qh[8*(ib32%4)+2*j+1] << (8 - 4*(ib32/4))) & 0xf00) + ((sh & 56) << 9); + int2 v = {0, 0}; + for (int k = 0; k < 4; ++k) { + val1 *= ka; + val2 *= ka; + v.x |= (ggml_cuda_dp4a(val1 & km, 0x01010101, -126) & 0xff) << 8*k; + v.y |= (ggml_cuda_dp4a(val2 & km, 0x01010101, -126) & 0xff) << 8*k; + } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*ib32 + 2*j + 1] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + 8*ib32 + 2*j + 1] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const block_iq4_kt * bxi = (const block_iq4_kt *)(dptr + 1) + kbx0; + const int ls = (bxi->qs[threadIdx.x % 8] & 0xff) >> 1; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * (ls - 64); +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * (ls - 64); +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_kt; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ4_KT); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl_id.cu new file mode 100644 index 000000000..6ef74f1d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl_id.cu @@ -0,0 +1,4 @@ +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); +DECL_MMQ_CASE(GGML_TYPE_MXFP4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs_id.cu new file mode 100644 index 000000000..988f5da1a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
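Both KT loaders (load_tiles_iq3_kt earlier and load_tiles_iq4_kt above) regenerate the trellis samples on the fly from the packed indices stored in the block: the 32-bit state is repeatedly multiplied by the constant 0xCBAC1FED, every byte is masked to 6 bits, and the four bytes are summed with an offset of -126, which is what ggml_cuda_dp4a(val & 0x3f3f3f3f, 0x01010101, -126) computes. A scalar sketch of one step under that reading of dp4a (an illustration, not code from the patch):

// Scalar equivalent of one trellis step in the KT loaders; val is the running 32-bit state.
#include <cstdint>

static inline int8_t iq_kt_trellis_step(uint32_t & val) {
    val *= 0xCBAC1FEDu;                       // `ka` in the loaders
    const uint32_t m = val & 0x3f3f3f3fu;     // `km`: keep 6 bits of every byte
    int sum = -126;                           // dp4a(m, 0x01010101, -126)
    for (int b = 0; b < 4; ++b) {
        sum += (m >> 8*b) & 0xff;
    }
    // iq4_kt keeps the low 8 bits of the sum directly; iq3_kt keeps |sum| and applies the sign from qh.
    return (int8_t)sum;
}

Four consecutive steps fill one packed int per lane, matching the << 8*k accumulation in the loaders.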
+ +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_k_id.cu new file mode 100644 index 000000000..853efb4c4 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_k_id.cu @@ -0,0 +1,174 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq5_k( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + + auto values = iq5nl_values; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq5_k * bxi = (const block_iq5_k *)(x + i*stride) + kbx0; + + int qh = get_int_b4(bxi->qh, kqsx); + uint16_t extra = bxi->extra >> (kqsx/4); + + #pragma unroll + for (int l = 0; l < qstep/2; ++l) { + + const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); + aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((extra & 1) * 0x20202020); // this is very slightly faster + aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((extra & 4) * 0x08080808); // then the version below + //aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((extra & 1) ? 0x20202020 : 0); + //aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((extra & 4) ? 
0x20202020 : 0); + qh >>= 2; + extra >>= 4; + + const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]); + const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 0] = *(const int *)&val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 8] = *(const int *)&val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1; +#endif // INT8_MMA_AVAILABLE + } + + const uint8_t sh = bxi->scales_h[kqsx/2] >> 4*(kqsx%2); + const int ls1 = ((bxi->scales_l[kqsx] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int ls2 = ((bxi->scales_l[kqsx] >> 4) | ((sh << 2) & 0x30)) - 32; + + const float d = bxi->d; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * ls1; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * ls2; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * ls1; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * ls2; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq5_k_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; // 0...7 -> block of 32 + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const block_iq5_k_r4 * bxi = (const block_iq5_k_r4 *)(x + 4*i4*stride) + kbx0; + + const float d = __half2float(bxi->d[ir]); + + int qh = get_int_b4(bxi->qh, 4*kqsx + ir); + + #pragma unroll + for (int l = 0; l < 2; ++l) { + + auto values_l = iq5nl_values + (((bxi->extra[ir+4*l] >> kqsx) & 1) << 5); + + const int ql1 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 0); + const int ql2 = get_int_b4(bxi->qs, 16*kqsx + ir + 4*l + 8); + aux32[0] = ((ql1 >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010); + aux32[1] = ((ql1 >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010); + aux32[2] = ((ql2 >> 0) & 0x0f0f0f0f) | ((qh >> 0) & 0x10101010); + aux32[3] = ((ql2 >> 4) & 0x0f0f0f0f) | ((qh >> 1) & 0x10101010); + + const char4 val0 = make_char4(values_l[aux8[ 0]], values_l[aux8[ 1]], values_l[aux8[ 2]], values_l[aux8[ 3]]); + const char4 val1 = make_char4(values_l[aux8[ 4]], values_l[aux8[ 5]], values_l[aux8[ 6]], values_l[aux8[ 7]]); + const char4 val2 = make_char4(values_l[aux8[ 8]], values_l[aux8[ 9]], values_l[aux8[10]], values_l[aux8[11]]); + const char4 val3 = make_char4(values_l[aux8[12]], values_l[aux8[13]], values_l[aux8[14]], values_l[aux8[15]]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 0] = *(const int *)&val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 2] = *(const int *)&val1; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 1] = *(const int *)&val2; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + 4*l + 3] = *(const int *)&val3; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 0] = *(const int *)&val0; + 
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 2] = *(const int *)&val1; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 1] = *(const int *)&val2; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + 4*l + 3] = *(const int *)&val3; +#endif // INT8_MMA_AVAILABLE + + qh >>= 2; + } + + int is = 8*kqsx + ir; + float dl1 = d * ((((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bxi->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + is += 4; + float dl2 = d * ((((bxi->scales_l[is%32] >> 4*(is/32)) & 0xf) | (((bxi->scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = dl1; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = dl2; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = dl1; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = dl2; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_k; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_k_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ5_K); +DECL_MMQ_CASE(GGML_TYPE_IQ5_K_R4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_id.cu new file mode 100644 index 000000000..3232fa085 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq5_ks_id.cu @@ -0,0 +1,146 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
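The IQ5_K loaders just above (and the IQ5_KS loader that follows) all build a 6-bit index per weight and run it through iq5nl_values: the low nibble of qs, one bit from qh shifted into bit 4, and the block's extra bit OR-ed in as bit 5, so the table is assumed here to hold 64 signed entries with the upper half selected by the extra bit. A scalar sketch of that lookup (an illustration, not code from the patch):

// Mirrors the byte-wise index assembly in load_tiles_iq5_k / load_tiles_iq5_k_r4 above.
#include <cstdint>

static inline int8_t iq5_k_lookup(uint8_t nibble,       // 4 bits from bxi->qs
                                   int high_bit,         // 1 bit from bxi->qh
                                   int extra_bit,        // 1 bit from bxi->extra
                                   const int8_t * iq5nl_values /* assumed 64-entry table */) {
    const int idx = (nibble & 0xf) | ((high_bit & 1) << 4) | ((extra_bit & 1) << 5);
    return iq5nl_values[idx];
}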
+ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq5_ks( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ5_KS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + + auto values = iq5nl_values; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const float d = dptr[0]; + const block_iq5_ks * bxi = (const block_iq5_ks *)(dptr + 1) + kbx0; + + int qh = get_int_b4(bxi->qh, kqsx); + + #pragma unroll + for (int l = 0; l < qstep/2; ++l) { + + const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); + aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh & 0x01010101) << 4) | ((bxi->scales[2*l+0] & 1) * 0x20202020); + aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh & 0x02020202) << 3) | ((bxi->scales[2*l+1] & 1) * 0x20202020); + qh >>= 2; + + const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]); + const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 16*l + 0] = *(const int *)&val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx + 16*l + 8] = *(const int *)&val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1; +#endif // INT8_MMA_AVAILABLE + } + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ((bxi->scales[kqsx] & 254) - 127); +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + kqsx] = d * ((bxi->scales[kqsx] & 254) - 127); +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq5_ks_r4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ5_KS_R4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x/4; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) { + int i = i0 + 4*threadIdx.y + threadIdx.x%4; + + if (need_check) { + i = min(i, i_max); + } + int i4 = i/4; + int ir = i%4; + + const float * dptr = (const float *)(x + 4*i4*stride); + const block_iq5_ks_r4 * bxi = (const block_iq5_ks_r4 *)(dptr + 4) + kbx0; + + const int ls = (bxi->scales[4*kqsx + ir] & 254) - 127; + auto values = iq5nl_values + ((bxi->scales[4*kqsx+ir] & 1) << 5); + + int qh = *((const int *)bxi->qh + 4*kqsx + ir); + const int * ql = (const int *)bxi->qs + 16*kqsx + ir; +#pragma unroll + for (int j = 0; j < 4; ++j) { + aux32[0] = 
((ql[4*j] >> 0) & 0x0f0f0f0f) | ((qh << 4) & 0x10101010); + aux32[1] = ((ql[4*j] >> 4) & 0x0f0f0f0f) | ((qh << 3) & 0x10101010); + qh >>= 2; + const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]); + const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]); + const int k0 = 8*kqsx + 4*(j%2) + j/2; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = *(const int *)&val0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 2] = *(const int *)&val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = *(const int *)&val0; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 2] = *(const int *)&val1; +#endif // INT8_MMA_AVAILABLE + } +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[ir] * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[ir] * ls; +#endif // INT8_MMA_AVAILABLE + + } + +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq5_ks_r4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ5_KS); +DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq6_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq6_k_id.cu new file mode 100644 index 000000000..af05bd768 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq6_k_id.cu @@ -0,0 +1,80 @@ +#include "../mmq_id_common.cuh" + +template static __device__ __forceinline__ void load_tiles_iq6_k( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + + constexpr int nwarps = mmq_get_nwarps_device(); + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + constexpr int qstep = 8; + const int kqsx = threadIdx.x % qstep; + + auto values = iq6nl_values; + int qh[2]; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/qstep) { + int i = i0 + threadIdx.y*(WARP_SIZE/qstep) + threadIdx.x/qstep; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq6_k * bxi = (const block_iq6_k *)(x + i*stride) + kbx0; + + const float d = bxi->d; + uint16_t extra = bxi->extra >> (kqsx/4); + + qh[0] = get_int_b4(bxi->qh, kqsx+0); + qh[1] = get_int_b4(bxi->qh, kqsx+8); + + #pragma unroll + for (int l = 0; l < qstep/2; ++l) { + + const int ql = get_int_b4(bxi->qs, kqsx + qstep*l); + aux32[0] = ((ql >> 0) & 0x0f0f0f0f) | ((qh[l/2] & 0x03030303) << 4) | ((extra & 1) * 0x40404040); + aux32[1] = ((ql >> 4) & 0x0f0f0f0f) | ((qh[l/2] & 0x0c0c0c0c) << 2) | ((extra & 4) * 0x10101010); + qh[l/2] >>= 4; + extra >>= 4; + + const char4 val0 = make_char4(values[aux8[0]], values[aux8[1]], values[aux8[2]], values[aux8[3]]); + const char4 val1 = make_char4(values[aux8[4]], values[aux8[5]], values[aux8[6]], values[aux8[7]]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 0] = *(const int *)&val0; 
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + kqsx + 16*l + 8] = *(const int *)&val1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 0] = *(const int *)&val0; + x_qs[i*(2*WARP_SIZE + 1) + kqsx + 16*l + 8] = *(const int *)&val1; +#endif // INT8_MMA_AVAILABLE + } + + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = d * bxi->scales[2*kqsx+0]; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = d * bxi->scales[2*kqsx+1]; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = d * bxi->scales[2*kqsx+0]; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = d * bxi->scales[2*kqsx+1]; +#endif // INT8_MMA_AVAILABLE + } +} + +template +struct mmq_type_traits_id { + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq6_k; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +DECL_MMQ_CASE(GGML_TYPE_IQ6_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k_id.cu new file mode 100644 index 000000000..14d78d839 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q2_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k_id.cu new file mode 100644 index 000000000..4262f49aa --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q3_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0_id.cu new file mode 100644 index 000000000..4e747a8ca --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1_id.cu new file mode 100644 index 000000000..a29a943ed --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k_id.cu new file mode 100644 index 000000000..c16f76138 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0_id.cu new file mode 100644 index 000000000..3afa037ad --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
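load_tiles_iq6_k a few files above follows the same pattern with one more high bit: two bits from qh land in bits 4 and 5 of the index and the extra bit in bit 6, so iq6nl_values is assumed here to be a 128-entry signed table, with each group of 16 weights scaled by d * bxi->scales[...] exactly as written in the loader. A scalar sketch of the index assembly (illustration only):

// Mirrors the aux32[] construction in load_tiles_iq6_k: nibble | qh bits << 4 | extra << 6.
#include <cstdint>

static inline int8_t iq6_k_lookup(uint8_t nibble,       // 4 bits from bxi->qs
                                   int qh_bits,          // 2 bits from bxi->qh
                                   int extra_bit,        // 1 bit from bxi->extra
                                   const int8_t * iq6nl_values /* assumed 128-entry table */) {
    const int idx = (nibble & 0xf) | ((qh_bits & 3) << 4) | ((extra_bit & 1) << 6);
    return iq6nl_values[idx];
}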
+ +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1_id.cu new file mode 100644 index 000000000..1c161297f --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k_id.cu new file mode 100644 index 000000000..36c7a9fae --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q5_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0_id.cu new file mode 100644 index 000000000..7da267a42 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q6_0); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k_id.cu new file mode 100644 index 000000000..2a02f965a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q6_K); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0_id.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0_id.cu new file mode 100644 index 000000000..3b126bed9 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0_id.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq_id_common.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index 840809a15..30ddbb3c2 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -6,6 +6,10 @@ #include #include +#if CUDART_VERSION >= 12050 +#include +#endif // CUDART_VERSION >= 12050 + #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH