@@ -150,138 +150,132 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
 }
 
 static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // weights
-    const ggml_tensor * src1 = dst->src[1]; // inputs
-    const ggml_tensor * src2 = dst->src[2]; // ids
-
-    GGML_TENSOR_TERNARY_OP_LOCALS
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * ids = dst->src[2];
 
-    const ggml_type type = src0->type;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    GGML_ASSERT(ne10 == ne00);
-    GGML_ASSERT(ne21 == ne12);
-    GGML_ASSERT(ne22 == 1 || ne22 == ne13);
-    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+    const enum ggml_type type = src0->type;
 
     GGML_ASSERT(nb00 == ggml_type_size(type));
     GGML_ASSERT(nb10 == ggml_type_size(src1->type));
 
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
     GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1 && nb1 <= nb2 && nb2 <= nb3);
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne03 == 1);
+    GGML_ASSERT(ne13 == 1);
+    GGML_ASSERT(ne3 == 1);
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
-    const int64_t n_used = (int64_t)ne20;
-    GGML_ASSERT(n_used <= ne02);
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    GGML_UNUSED(r2);
+    GGML_UNUSED(r3);
+
+    const int64_t ne_plane = ne01*ne00;
+    const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
 
-    const int64_t ne_plane = ne01 * ne00;
-    const size_t desired_wsize = (type == GGML_TYPE_F32) ? 0 : ne03 * ne02 * ne_plane * sizeof(float);
     if (ctx->work_size < desired_wsize) {
         ctx->work_data.reset(new char[desired_wsize]);
         ctx->work_size = desired_wsize;
     }
     void * wdata = ctx->work_data.get();
 
+    // convert src0 to float
     if (type != GGML_TYPE_F32) {
         const auto * type_traits = ggml_get_type_traits(type);
-        ggml_to_float_t to_float = type_traits->to_float;
+        ggml_to_float_t const to_float = type_traits->to_float;
 
-        for (int64_t i03 = 0; i03 < ne03; ++i03) {
-            for (int64_t i02 = 0; i02 < ne02; ++i02) {
-                const void * x = (char *)src0->data + i02*nb02 + i03*nb03;
-                float * wplane = (float *)wdata + i02*ne_plane + i03*ne02*ne_plane;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+                float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
 
                 const int min_cols_per_thread = 4096;
-                const int min_rows_per_thread = std::max((int)(min_cols_per_thread / ne00), 1);
-                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01 / min_rows_per_thread)), 1);
+                const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
+                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
 
 #ifdef GGML_USE_OPENMP
                 #pragma omp parallel for num_threads(n_threads)
-                for (int64_t i01 = 0; i01 < ne01; ++i01) {
-                    to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                 }
 #else
                 for (int i = 1; i < n_threads; i++) {
-                    const int64_t start = i * ne01/n_threads;
-                    const int64_t end = (i + 1) * ne01/n_threads;
+                    const int64_t start =       i*ne01/n_threads;
+                    const int64_t end   = (i + 1)*ne01/n_threads;
                     if (start < end) {
                         ctx->tasks.push_back(std::async(std::launch::async, [=]() {
-                            for (int64_t i01 = start; i01 < end; ++i01) {
-                                to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
+                            for (int64_t i01 = start; i01 < end; i01++) {
+                                to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                             }
                         }));
                     }
                 }
                 {
+                    // reuse the current thread for the first task
                     const int64_t start = 0;
-                    const int64_t end = ne01/n_threads;
+                    const int64_t end   = ne01/n_threads;
                     for (int64_t i01 = start; i01 < end; i01++) {
-                        to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
+                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                     }
                 }
 #endif
             }
         }
 
 #ifndef GGML_USE_OPENMP
-        for (auto & task: ctx->tasks) {
+        // wait for all tasks to finish
+        for (auto & task : ctx->tasks) {
             task.get();
         }
         ctx->tasks.clear();
 #endif
     }
 
-#ifdef OPENBLAS_VERSION
+#if defined(OPENBLAS_VERSION)
     openblas_set_num_threads(ctx->n_threads);
 #endif
 
-#ifdef GGML_BLAS_USE_BLIS
+#if defined(GGML_BLAS_USE_BLIS)
     bli_thread_set_num_threads(ctx->n_threads);
 #endif
 
-#ifdef GGML_BLAS_USE_NVPL
+#if defined(GGML_BLAS_USE_NVPL)
     nvpl_blas_set_num_threads(ctx->n_threads);
 #endif
 
-    for (int64_t i13 = 0; i13 < ne13; ++i13) {
-        for (int64_t j = 0; j < ne12; ++j) {
-            const int64_t ids_batch_index = (ne22 > 1 ? i13 : 0);
-            const int32_t * ids_row = (const int32_t *)((char *)src2->data + ids_batch_index*nb22 + j*nb21);
-            float * out_ptr = (float *)((char *)dst->data + i13*nb3 + j*nb2);
-
-            for (int iE = 0; iE < n_used; ++iE) {
-                const int expert_id = ids_row[iE];
-                GGML_ASSERT(expert_id < ne02);
-
-                const float * wmat;
-                if (type == GGML_TYPE_F32) {
-                    wmat = (const float *)((char *)src0->data + expert_id*nb02);
-                } else {
-                    wmat = (const float *)((char *)wdata + expert_id * ne_plane * sizeof(float));
-                }
+    const int n_ids = ids->ne[0];
+    const int n_tokens = ids->ne[1];
 
-                if (ne03 > 1) {
-                    int64_t w_batch_index = (ne03 == ne13 ? i13 : 0);
-                    wmat = (const float *)((char *)wdata + (w_batch_index * ne02 + expert_id) * ne_plane * sizeof(float));
-                }
+    for (int t = 0; t < n_tokens; ++t) {
+        for (int e = 0; e < n_ids; ++e) {
+            const int32_t expert = *(const int32_t *) ((const char *) ids->data + e*ids->nb[0] + t*ids->nb[1]);
+            GGML_ASSERT(expert >= 0 && expert < ne02);
 
-                const float * inp = (const float *)((char *)src1->data
-                    + ((ne11 == 1 ? 0 : iE) * nb11)
-                    + j * nb12 + i13 * nb13);
-
-                if (iE == 0) {
-                    cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)ne01, (int)ne00,
-                                1.0f, wmat, (int)ne00,
-                                inp, 1,
-                                0.0f,
-                                out_ptr, 1);
-                } else {
-                    cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)ne01, (int)ne00,
-                                1.0f, wmat, (int)ne00,
-                                inp, 1,
-                                1.0f,
-                                out_ptr, 1);
-                }
+            const int e_src1 = e % ne11;
+
+            const float * a = (float *) ((char *) src0->data + expert*nb02);
+            const float * b = (float *) ((char *) src1->data + e_src1*nb11 + t*nb12);
+            float * d = (float *) ((char *) dst->data + e*nb1 + t*nb2);
+
+            if (type != GGML_TYPE_F32) {
+                a = (float *) wdata + expert*ne_plane;
             }
+
+            cblas_sgemv(CblasRowMajor, CblasNoTrans,
+                        ne01, ne00,
+                        1.0f, a, ne00,
+                        b, 1,
+                        0.0f, d, 1);
         }
     }
 }
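
For readers unfamiliar with the indexing in the new loop, here is a small standalone sketch (not part of this commit) of the same access pattern: for each token t and each selected expert slot e, the expert id is read from ids, the matching expert matrix is chosen, and one matrix-vector product is written to the output slice for (e, t). Plain arrays and a naive inner loop stand in for the ggml tensors and cblas_sgemv; all shapes and names below are illustrative, not taken from the code above.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // illustrative (hypothetical) sizes
    const int ne00     = 4;  // input features  (columns of each expert matrix)
    const int ne01     = 3;  // output features (rows of each expert matrix)
    const int ne02     = 2;  // number of experts
    const int n_ids    = 2;  // experts selected per token (ids->ne[0])
    const int n_tokens = 2;  // ids->ne[1]

    // expert weights [ne02][ne01][ne00] and one input column per token [n_tokens][ne00]
    // (ne11 == 1 here, so every expert slot of a token reads the same column, i.e. e_src1 = e % ne11 = 0)
    std::vector<float>   w(ne02*ne01*ne00, 0.5f);
    std::vector<float>   x(n_tokens*ne00, 1.0f);
    std::vector<int32_t> ids = {0, 1,   // token 0 -> experts 0 and 1
                                1, 0};  // token 1 -> experts 1 and 0
    std::vector<float>   out(n_tokens*n_ids*ne01, 0.0f);

    for (int t = 0; t < n_tokens; ++t) {
        for (int e = 0; e < n_ids; ++e) {
            const int32_t expert = ids[t*n_ids + e];       // which expert handles slot e of token t
            const float * a = w.data() + expert*ne01*ne00; // that expert's weight matrix
            const float * b = x.data() + t*ne00;           // input column for this token
            float       * d = out.data() + (t*n_ids + e)*ne01;
            // d = A*b, the product cblas_sgemv computes with alpha = 1 and beta = 0
            for (int i = 0; i < ne01; ++i) {
                float sum = 0.0f;
                for (int k = 0; k < ne00; ++k) {
                    sum += a[i*ne00 + k]*b[k];
                }
                d[i] = sum;
            }
        }
    }

    printf("out[0] = %.1f\n", out[0]); // 4 inputs * 0.5 weight = 2.0
    return 0;
}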