@@ -2831,13 +2831,19 @@ catch (sycl::exception const &exc) {
28312831 std::exit (1 );
28322832}
28332833
/**
 * Identifies which mul_mat kernel path ggml_sycl_mul_mat is about to dispatch.
 *
 * The value is handed to opt_for_reorder so that reorder support can be checked
 * for the specific path that will consume the tensor.
 */
enum class Mul_Mat_Algo {
    DMMV         = 0,  // path dispatched to ggml_sycl_op_dequantize_mul_mat_vec
    MMVQ         = 1,  // path dispatched to ggml_sycl_op_mul_mat_vec_q
    MUL_MAT_SYCL = 2,  // fallback path dispatched to ggml_sycl_op_mul_mat_sycl
};
2839+
28342840inline bool ggml_sycl_supports_mmq (enum ggml_type type) {
28352841 // TODO: accuracy issues in MMQ
28362842 GGML_UNUSED (type);
28372843 return false ;
28382844}
28392845
2840- inline bool ggml_sycl_supports_reorder_dequantize (enum ggml_type type) {
2846+ inline bool ggml_sycl_supports_reorder_mul_mat_sycl (enum ggml_type type) {
28412847 switch (type) {
28422848 case GGML_TYPE_Q4_0:
28432849 return true ;
@@ -2927,20 +2933,37 @@ static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_ten
29272933 dst->src [1 ]->ne [2 ]==1 && dst->src [1 ]->ne [3 ]==1 ;
29282934}
29292935
2930- /*
2931- * This function could be called when the OP (mul_mat) function support reorder optimizition.
2932- */
2933- static void opt_for_reorder (ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
2934- ggml_tensor * dst) {
2935- if (should_reorder_tensor (*ctx, dst)) {
2936- ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra ;
2937- if (!extra) return ; // only happen in CI/UT permute case.
2936+ static void opt_for_reorder (ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */ ,
2937+ ggml_tensor * dst, Mul_Mat_Algo mul_mat_algo) {
2938+ if (!should_reorder_tensor (*ctx, dst)) {
2939+ return ;
2940+ }
29382941
2939- if (extra->optimized_feature .reorder ) return ; // skip the tensor which is handled for reorder.
2942+ ggml_tensor_extra_gpu * extra = static_cast <ggml_tensor_extra_gpu *>(src0->extra );
2943+ if (!extra || extra->optimized_feature .reorder ) {
2944+ return ; // Skip permutations and already reordered tensors
2945+ }
29402946
2941- reorder_qw (src0, ctx->stream ());
2942- extra->optimized_feature .reorder = true ; // used to decode/dequan in next steps.
2947+ switch (mul_mat_algo) {
2948+ case Mul_Mat_Algo::DMMV:
2949+ if (!ggml_sycl_supports_reorder_dmmv (src0->type )) {
2950+ return ;
2951+ }
2952+ break ;
2953+ case Mul_Mat_Algo::MMVQ:
2954+ if (!ggml_sycl_supports_reorder_mmvq (src0->type )) {
2955+ return ;
2956+ }
2957+ break ;
2958+ case Mul_Mat_Algo::MUL_MAT_SYCL:
2959+ if (!ggml_sycl_supports_reorder_mul_mat_sycl (src0->type )) {
2960+ return ;
2961+ }
2962+ break ;
29432963 }
2964+
2965+ reorder_qw (src0, ctx->stream ());
2966+ extra->optimized_feature .reorder = true ; // Used to decode/dequan in next steps and avoid re-reordering
29442967}
29452968
29462969static void ggml_sycl_mul_mat (ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -3013,24 +3036,19 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
30133036 ggml_sycl_mul_mat_batched_sycl (ctx, src0, src1, dst);
30143037 } else if (use_dequantize_mul_mat_vec) {
30153038 constexpr bool convert_src1_to_q8_1 = false ;
3016- if (ggml_sycl_supports_reorder_dmmv (src0->type )) {
3017- opt_for_reorder (&ctx, src0, src1, dst);
3018- }
3039+ opt_for_reorder (&ctx, src0, src1, dst, Mul_Mat_Algo::DMMV);
30193040 ggml_sycl_op_mul_mat (ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
30203041 } else if (use_mul_mat_vec_q) {
30213042 constexpr bool convert_src1_to_q8_1 = true ;
3022- if (ggml_sycl_supports_reorder_mmvq (src0->type )) {
3023- opt_for_reorder (&ctx, src0, src1, dst);
3024- }
3043+ opt_for_reorder (&ctx, src0, src1, dst, Mul_Mat_Algo::MMVQ);
30253044 ggml_sycl_op_mul_mat (ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
30263045 } else if (use_mul_mat_q) {
30273046 constexpr bool convert_src1_to_q8_1 = true ;
30283047 ggml_sycl_op_mul_mat (ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
30293048 } else {
30303049 constexpr bool convert_src1_to_q8_1 = false ;
3031- if (ggml_sycl_supports_reorder_dequantize (src0->type )) {
3032- opt_for_reorder (&ctx, src0, src1, dst); // the OP function in this branch support reorder.
3033- }
3050+ // MUL_MAT_SYCL supports reorder
3051+ opt_for_reorder (&ctx, src0, src1, dst, Mul_Mat_Algo::MUL_MAT_SYCL);
30343052 ggml_sycl_op_mul_mat (ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
30353053 }
30363054 GGML_SYCL_DEBUG (" call %s done\n " , __func__);
0 commit comments