@@ -149,6 +149,143 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
     }
 }
 
+static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // weights
+    const ggml_tensor * src1 = dst->src[1]; // inputs
+    const ggml_tensor * src2 = dst->src[2]; // ids
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    const ggml_type type = src0->type;
+
+    GGML_ASSERT(ne10 == ne00);
+    GGML_ASSERT(ne21 == ne12);
+    GGML_ASSERT(ne22 == 1 || ne22 == ne13);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1 && nb1 <= nb2 && nb2 <= nb3);
+
+    const int64_t n_used = (int64_t)ne20;
+    GGML_ASSERT(n_used <= ne02);
+
+    const int64_t ne_plane = ne01*ne00;
+    const size_t desired_wsize = (type == GGML_TYPE_F32) ? 0 : ne03*ne02*ne_plane*sizeof(float);
+    if (ctx->work_size < desired_wsize) {
+        ctx->work_data.reset(new char[desired_wsize]);
+        ctx->work_size = desired_wsize;
+    }
+    void * wdata = ctx->work_data.get();
+
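+    // non-F32 weights: dequantize every expert plane into the F32 workspace so the
+    // sgemv calls below can read the weights directly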
+    if (type != GGML_TYPE_F32) {
+        const auto * type_traits = ggml_get_type_traits(type);
+        ggml_to_float_t to_float = type_traits->to_float;
+
+        for (int64_t i03 = 0; i03 < ne03; ++i03) {
+            for (int64_t i02 = 0; i02 < ne02; ++i02) {
+                const void * x = (char *)src0->data + i02*nb02 + i03*nb03;
+                float * wplane = (float *)wdata + i02*ne_plane + i03*ne02*ne_plane;
+
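+                // parallelize the conversion over rows, capping the thread count so each
+                // thread gets roughly min_cols_per_thread elements of work at minimum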
+                const int min_cols_per_thread = 4096;
+                const int min_rows_per_thread = std::max((int)(min_cols_per_thread / ne00), 1);
+                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01 / min_rows_per_thread)), 1);
+
+#ifdef GGML_USE_OPENMP
+                #pragma omp parallel for num_threads(n_threads)
+                for (int64_t i01 = 0; i01 < ne01; ++i01) {
+                    to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
+                }
+#else
+                for (int i = 1; i < n_threads; i++) {
+                    const int64_t start = i * ne01/n_threads;
+                    const int64_t end   = (i + 1) * ne01/n_threads;
+                    if (start < end) {
+                        ctx->tasks.push_back(std::async(std::launch::async, [=]() {
+                            for (int64_t i01 = start; i01 < end; ++i01) {
+                                to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
+                            }
+                        }));
+                    }
+                }
+                {
+                    const int64_t start = 0;
+                    const int64_t end   = ne01/n_threads;
+                    for (int64_t i01 = start; i01 < end; i01++) {
+                        to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
+                    }
+                }
+#endif
+            }
+        }
+
+#ifndef GGML_USE_OPENMP
+        for (auto & task : ctx->tasks) {
+            task.get();
+        }
+        ctx->tasks.clear();
+#endif
+    }
+
+#ifdef OPENBLAS_VERSION
+    openblas_set_num_threads(ctx->n_threads);
+#endif
+
+#ifdef GGML_BLAS_USE_BLIS
+    bli_thread_set_num_threads(ctx->n_threads);
+#endif
+
+#ifdef GGML_BLAS_USE_NVPL
+    nvpl_blas_set_num_threads(ctx->n_threads);
+#endif
+
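+    // for each (sequence i13, token j): look up the row of selected expert ids and
+    // run one sgemv per used expert, combining the results in the output row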
+    for (int64_t i13 = 0; i13 < ne13; ++i13) {
+        for (int64_t j = 0; j < ne12; ++j) {
+            const int64_t ids_batch_index = (ne22 > 1 ? i13 : 0);
+            const int32_t * ids_row = (const int32_t *)((char *)src2->data + ids_batch_index*nb22 + j*nb21);
+            float * out_ptr = (float *)((char *)dst->data + i13*nb3 + j*nb2);
+
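+            // src1 provides either one input row per used expert (ne11 > 1) or a
+            // single row that is shared by all of them (ne11 == 1)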
+            for (int iE = 0; iE < n_used; ++iE) {
+                const int expert_id = ids_row[iE];
+                GGML_ASSERT(expert_id < ne02);
+
+                // select the weight matrix of this expert (and of this batch, if src0 is batched);
+                // F32 weights are read from src0 directly, converted weights from the F32 workspace
+                const int64_t w_batch_index = (ne03 == ne13 ? i13 : 0);
+                const float * wmat;
+                if (type == GGML_TYPE_F32) {
+                    wmat = (const float *)((char *)src0->data + expert_id*nb02 + w_batch_index*nb03);
+                } else {
+                    wmat = (const float *)wdata + (w_batch_index*ne02 + expert_id)*ne_plane;
+                }
+
+                const float * inp = (const float *)((char *)src1->data
+                                                    + ((ne11 == 1 ? 0 : iE) * nb11)
+                                                    + j * nb12 + i13 * nb13);
+
+                // beta = 0 for the first expert (overwrite the output row), beta = 1 afterwards (accumulate)
+                const float beta = (iE == 0) ? 0.0f : 1.0f;
+                cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)ne01, (int)ne00,
+                            1.0f, wmat, (int)ne00,
+                            inp, 1,
+                            beta,
+                            out_ptr, 1);
+            }
+        }
+    }
+}
+
 static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
@@ -235,6 +372,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
                 ggml_backend_blas_mul_mat(ctx, node);
                 break;
 
+            case GGML_OP_MUL_MAT_ID:
+                ggml_backend_blas_mul_mat_id(ctx, node);
+                break;
+
             case GGML_OP_OUT_PROD:
                 ggml_backend_blas_out_prod(ctx, node);
                 break;
@@ -418,6 +559,19 @@ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const s
                    (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
         }
 
+        case GGML_OP_MUL_MAT_ID:
+        {
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
+            const struct ggml_tensor * src2 = op->src[2];
+
+            // same requirements as the compute path: I32 ids, F32 activations,
+            // and weights that are either F32 or convertible to F32
+            return src2->type == GGML_TYPE_I32 &&
+                   src1->type == GGML_TYPE_F32 &&
+                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
+        }
+
         case GGML_OP_OUT_PROD:
             return op->src[0]->type == GGML_TYPE_F32 &&
                    op->src[1]->type == GGML_TYPE_F32 &&