@@ -27,13 +27,16 @@ struct ggml_backend_blas_context {
 #endif
 };
 
-static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+static void ggml_backend_blas_mul_mat(
+        ggml_backend_blas_context * ctx,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    const enum ggml_type type = src0->type;
+    const ggml_type type = src0->type;
 
     GGML_ASSERT(ne0 == ne01);
     GGML_ASSERT(ne1 == ne11);
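Aside from splitting the signature across lines, this hunk (and the matching mul_mat_id hunk below) drops the `struct`/`enum` keywords in front of `ggml_tensor` and `ggml_type`. Those elaborated-type-specifiers are required in C but optional in C++ once the type has been declared; a minimal standalone illustration with stand-in types, not the ggml ones:

```cpp
// In C++ (unlike C), the struct/enum keyword is optional after the type
// has been declared, so `struct ggml_tensor *` and `enum ggml_type`
// can be written as `ggml_tensor *` and `ggml_type`.
// Stand-in types for illustration only.
struct my_tensor { int type; };
enum my_type { MY_TYPE_F32 = 0 };

int main() {
    my_tensor t  = {};          // no `struct` keyword needed
    my_type   ty = MY_TYPE_F32; // no `enum` keyword needed
    (void) t;
    (void) ty;
    return 0;
}
```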
@@ -70,8 +73,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
-            const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
-            float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
+            const void  * x      = (char  *) src0->data + i02*nb02     + i03*nb03;
+            float * const wplane = (float *) wdata      + i02*ne_plane + i03*ne02*ne_plane;
 
             const int min_cols_per_thread = 4096;
             const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
@@ -84,8 +87,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
             }
 #else
             for (int i = 1; i < n_threads; i++) {
-                const int64_t start =       i*ne01/n_threads;
-                const int64_t end   = (i + 1)*ne01/n_threads;
+                const int64_t start = (i + 0)*ne01/n_threads;
+                const int64_t end   = (i + 1)*ne01/n_threads;
                 if (start < end) {
                     ctx->tasks.push_back(std::async(std::launch::async, [=]() {
                         for (int64_t i01 = start; i01 < end; i01++) {
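The `start` change here is purely cosmetic: `(i + 0)*ne01/n_threads` equals `i*ne01/n_threads` and only lines the expression up with the `end` computation (the same tweak appears again in the mul_mat_id hunk further down). The partitioning itself splits the `ne01` rows into `n_threads` contiguous chunks by integer division; a minimal sketch of that scheme with illustrative values, not ggml code:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Rows [0, ne01) are split into n_threads contiguous chunks, with
    // thread i covering [i*ne01/n_threads, (i+1)*ne01/n_threads).
    const int64_t ne01      = 10; // rows to process (illustrative)
    const int     n_threads = 4;

    for (int i = 0; i < n_threads; i++) {
        const int64_t start = (i + 0)*ne01/n_threads;
        const int64_t end   = (i + 1)*ne01/n_threads;
        // A chunk can be empty when ne01 < n_threads, hence the
        // `start < end` guard in the backend code.
        printf("thread %d: rows [%lld, %lld)\n", i, (long long) start, (long long) end);
    }
    return 0;
}
```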
@@ -149,14 +152,17 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
     }
 }
 
-static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * ids  = dst->src[2];
+static void ggml_backend_blas_mul_mat_id(
+        ggml_backend_blas_context * ctx,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * ids  = dst->src[2];
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    const enum ggml_type type = src0->type;
+    const ggml_type type = src0->type;
 
     GGML_ASSERT(nb00 == ggml_type_size(type));
     GGML_ASSERT(nb10 == ggml_type_size(src1->type));
@@ -173,15 +179,10 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type  == GGML_TYPE_I32);
 
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-    GGML_UNUSED(r2);
-    GGML_UNUSED(r3);
-
     const int64_t ne_plane = ne01*ne00;
-    const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
+    const size_t desired_wsize = type == GGML_TYPE_F32
+        ? 0
+        : ne03*ne02*ne_plane*sizeof(float);
 
     if (ctx->work_size < desired_wsize) {
         ctx->work_data.reset(new char[desired_wsize]);
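The deleted `r2`/`r3` broadcast factors were computed and then immediately marked `GGML_UNUSED`, so removing them is pure cleanup. The surviving sizing logic reserves a scratch buffer large enough to hold `src0` dequantized to F32, one `ne01*ne00` plane per `(i02, i03)` pair, and nothing at all when `src0` is already F32. A self-contained sketch of that arithmetic, with made-up dimensions:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

int main() {
    // Illustrative shape: ne03*ne02 planes of ne01*ne00 floats each.
    const int64_t ne00 = 4096, ne01 = 4096, ne02 = 2, ne03 = 1;
    const bool src0_is_f32 = false; // e.g. src0 holds a quantized type

    // F32 input needs no conversion buffer; anything else needs room
    // for every plane of src0 converted to float.
    const int64_t ne_plane      = ne01*ne00;
    const size_t  desired_wsize = src0_is_f32
        ? 0
        : ne03*ne02*ne_plane*sizeof(float);

    assert(desired_wsize == 1u*2u*4096u*4096u*sizeof(float)); // 128 MiB
    return 0;
}
```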
@@ -196,8 +197,8 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) {
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
-            const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
-            float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
+            const void  * x      = (char  *) src0->data + i02*nb02     + i03*nb03;
+            float * const wplane = (float *) wdata      + i02*ne_plane + i03*ne02*ne_plane;
 
             const int min_cols_per_thread = 4096;
             const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
@@ -210,7 +211,7 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) {
             }
 #else
             for (int i = 1; i < n_threads; i++) {
-                const int64_t start =       i*ne01/n_threads;
+                const int64_t start = (i + 0)*ne01/n_threads;
                 const int64_t end   = (i + 1)*ne01/n_threads;
                 if (start < end) {
                     ctx->tasks.push_back(std::async(std::launch::async, [=]() {
@@ -555,15 +556,13 @@ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
 
         case GGML_OP_MUL_MAT_ID:
             {
-                const struct ggml_tensor * src0 = op->src[0];
-                const struct ggml_tensor * src1 = op->src[1];
-                const struct ggml_tensor * src2 = op->src[2];
-
-                // GGML_LOG_INFO("%s: op=GGML_OP_MUL_MAT_ID src0_type=%s src1_type=%s src2_type=%s ne0=%lld ne1=%lld ne2=%lld ne3=%lld\n",
-                //               __func__, ggml_type_name(src0->type), ggml_type_name(src1->type), ggml_type_name(src2->type),
-                //               op->ne[0], op->ne[1], op->ne[2], op->ne[3]);
+                const ggml_tensor * src0 = op->src[0];
+                const ggml_tensor * src1 = op->src[1];
 
-                return src2->type == GGML_TYPE_I32;
+                return ggml_is_contiguous(src0) &&
+                       ggml_is_contiguous(src1) &&
+                       src1->type == GGML_TYPE_F32 &&
+                       (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
             }
 
         case GGML_OP_OUT_PROD:
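The rewritten `GGML_OP_MUL_MAT_ID` check no longer inspects the ids tensor at all: the op is accepted whenever both matrix operands are contiguous, `src1` is already F32, and `src0` is either F32 or a type whose traits table provides a `to_float` dequantization routine, which is what lets quantized `src0` types through. Restated as a free-standing helper, assuming the ggml headers; the name `blas_supports_mul_mat_id` is illustrative, not part of the patch:

```cpp
#include "ggml.h"

// Hypothetical helper mirroring the new supports-op condition: the BLAS
// backend can run MUL_MAT_ID iff both operands are contiguous, src1 is
// F32, and src0 is F32 or dequantizable to float via its type traits.
static bool blas_supports_mul_mat_id(const ggml_tensor * op) {
    const ggml_tensor * src0 = op->src[0];
    const ggml_tensor * src1 = op->src[1];

    return ggml_is_contiguous(src0) &&
           ggml_is_contiguous(src1) &&
           src1->type == GGML_TYPE_F32 &&
           (src0->type == GGML_TYPE_F32 ||
            ggml_get_type_traits(src0->type)->to_float != NULL);
}
```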