
Commit b1c8a17

airMeng authored and ggerganov committed
Align GEMM dispatch (llama/7566)
* align GEMM dispatch
1 parent c052218 · commit b1c8a17


src/ggml-sycl.cpp

Lines changed: 55 additions & 67 deletions
@@ -3022,20 +3022,19 @@ static int g_work_group_size = 0;
 // typedef sycl::half ggml_fp16_t;

 #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
-#define VER_4VEC 610 //todo for hardward optimize.
+#define VER_4VEC 130 //todo for hardward optimize.
 #define VER_GEN9 700 //todo for hardward optimize.
 #define VER_GEN12 1000000 //todo for hardward optimize.
 #define VER_GEN13 (VER_GEN12 + 1030) //todo for hardward optimize.

 #define GGML_SYCL_MAX_NODES 8192 //TODO: adapt to hardwares

-
-//define for XMX in Intel GPU
-//TODO: currently, it's not used for XMX really.
-#define SYCL_USE_XMX
+#if !defined(GGML_SYCL_FORCE_MMQ)
+    #define SYCL_USE_XMX
+#endif

 // max batch size to use MMQ kernels when tensor cores are available
-#define XMX_MAX_BATCH_SIZE 32
+#define MMQ_MAX_BATCH_SIZE 32


 #if defined(_MSC_VER)
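
Note on the new guard: XMX is now on by default and is opted out of, rather than opted into, by defining GGML_SYCL_FORCE_MMQ at build time. A minimal standalone sketch of the same preprocessor pattern follows; the file name, build commands, and output strings are illustrative assumptions, and how a given build system actually supplies the macro is not shown in this commit.

// demo_xmx_guard.cpp - toy model of the guard above, not part of ggml-sycl.cpp.
// Build: g++ demo_xmx_guard.cpp                          -> "XMX path enabled"
// Build: g++ -DGGML_SYCL_FORCE_MMQ demo_xmx_guard.cpp    -> "MMQ forced, XMX disabled"
#include <iostream>

#if !defined(GGML_SYCL_FORCE_MMQ)
    #define SYCL_USE_XMX   // mirrors the diff: XMX unless MMQ is forced
#endif

int main() {
#ifdef SYCL_USE_XMX
    // Default path: with XMX on, MMQ is capped at MMQ_MAX_BATCH_SIZE below.
    std::cout << "XMX path enabled\n";
#else
    std::cout << "MMQ forced, XMX disabled\n";
#endif
    return 0;
}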
@@ -15249,6 +15248,29 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }

+inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
+    // TODO: accuracy issues in MMQ
+    return false;
+}
+
+bool ggml_sycl_supports_dmmv(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_F16:
+            return true;
+        default:
+            return false;
+    }
+}

 static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
@@ -15265,76 +15287,42 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }

-#ifdef SYCL_USE_XMX
-    const bool use_xmx = true;
-#else
-    const bool use_xmx = false;
-#endif
+    // check data types and tensor shapes for custom matrix multiplication kernels:
+    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;

-    // debug helpers
-    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
-    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
-    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
-    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
-    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
-    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+    bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    // mmvq and mmq need the __dp4a instruction which is available for gen12+
+    // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
+    use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
+#ifdef SYCL_USE_XMX
+    use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
+#endif // SYCL_USE_XMX

-    if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
         ggml_sycl_mul_mat_vec_p021(src0, src1, dst);
-    } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
         ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
         ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
-    } else if (src0->type == GGML_TYPE_F32) {
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
-        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
-#ifdef GGML_SYCL_FORCE_DMMV
-            const bool use_mul_mat_vec_q = false;
-#else
-            bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-            use_mul_mat_vec_q = use_mul_mat_vec_q ||
-                (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
-                (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
-                (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
-                (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
-
-
-#endif // GGML_SYCL_FORCE_DMMV
-
-            if (use_mul_mat_vec_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-            }
-        } else {
-            bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-            use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
-
-            if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
-                use_mul_mat_q = false;
-            }
-
-            if (use_mul_mat_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-            }
-        }
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+    } else if (use_mul_mat_vec_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+    } else if (use_mul_mat_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
     } else {
-        GGML_ASSERT(false);
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
     }
 }

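The net effect of the dispatch rewrite: the per-branch capability checks (use_xmx, min_compute_capability against VER_4VEC/VER_GEN9, GGML_SYCL_FORCE_DMMV) are folded into a few booleans computed up front, and kernel selection becomes one flat if/else-if chain with the generic SYCL GEMM as the unconditional fallback, replacing the old GGML_ASSERT(false) dead end. Below is a compilable sketch of that precedence using hypothetical stand-in flags rather than the real tensor checks; dispatch_flags, select_kernel, and main are illustrative and not part of the commit.

// dispatch_order.cpp - toy model of the new kernel-selection order.
#include <cstdio>

// Stand-ins for the predicates ggml_sycl_mul_mat now computes up front.
struct dispatch_flags {
    bool f16_special_case;           // p021 / nc / batched F16 branches
    bool use_dequantize_mul_mat_vec; // ggml_sycl_supports_dmmv(...) && shape checks
    bool use_mul_mat_vec_q;          // quantized src0, batch <= MMVQ_MAX_BATCH_SIZE
    bool use_mul_mat_q;              // ggml_sycl_supports_mmq(...) && type checks
};

// First matching kernel wins; the generic GEMM is the fallback.
const char * select_kernel(const dispatch_flags & f) {
    if (f.f16_special_case)           return "F16 special-case kernel";
    if (f.use_dequantize_mul_mat_vec) return "dequantize_mul_mat_vec";
    if (f.use_mul_mat_vec_q)          return "mul_mat_vec_q";
    if (f.use_mul_mat_q)              return "mul_mat_q";
    return "mul_mat_sycl (fallback)";
}

int main() {
    dispatch_flags quantized_vec = {false, true, false, false};
    std::printf("%s\n", select_kernel(quantized_vec)); // dequantize_mul_mat_vec
    dispatch_flags nothing_fits = {false, false, false, false};
    std::printf("%s\n", select_kernel(nothing_fits));  // mul_mat_sycl (fallback)
    return 0;
}

Since ggml_sycl_supports_mmq currently returns false unconditionally, the mul_mat_q branch is dead code for now; per the TODO in the diff it is kept for when the MMQ accuracy issues are resolved.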