
Commit 19c8ec9

ggml-blas: fully working mmid

Signed-off-by: Aaron Teo <[email protected]>

1 parent f682374 commit 19c8ec9

File tree

1 file changed (+66 -72 lines changed)


ggml/src/ggml-blas/ggml-blas.cpp

Lines changed: 66 additions & 72 deletions
@@ -150,138 +150,132 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
 }
 
 static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // weights
-    const ggml_tensor * src1 = dst->src[1]; // inputs
-    const ggml_tensor * src2 = dst->src[2]; // ids
-
-    GGML_TENSOR_TERNARY_OP_LOCALS
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * ids  = dst->src[2];
 
-    const ggml_type type = src0->type;
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    GGML_ASSERT(ne10 == ne00);
-    GGML_ASSERT(ne21 == ne12);
-    GGML_ASSERT(ne22 == 1 || ne22 == ne13);
-    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+    const enum ggml_type type = src0->type;
 
     GGML_ASSERT(nb00 == ggml_type_size(type));
     GGML_ASSERT(nb10 == ggml_type_size(src1->type));
 
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
     GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1 && nb1 <= nb2 && nb2 <= nb3);
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne03 == 1);
+    GGML_ASSERT(ne13 == 1);
+    GGML_ASSERT(ne3 == 1);
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
-    const int64_t n_used = (int64_t)ne20;
-    GGML_ASSERT(n_used <= ne02);
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    GGML_UNUSED(r2);
+    GGML_UNUSED(r3);
+
+    const int64_t ne_plane = ne01*ne00;
+    const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
 
-    const int64_t ne_plane = ne01 * ne00;
-    const size_t desired_wsize = (type == GGML_TYPE_F32) ? 0 : ne03 * ne02 * ne_plane * sizeof(float);
     if (ctx->work_size < desired_wsize) {
         ctx->work_data.reset(new char[desired_wsize]);
         ctx->work_size = desired_wsize;
     }
     void * wdata = ctx->work_data.get();
 
+    // convert src0 to float
     if (type != GGML_TYPE_F32) {
         const auto * type_traits = ggml_get_type_traits(type);
-        ggml_to_float_t to_float = type_traits->to_float;
+        ggml_to_float_t const to_float = type_traits->to_float;
 
-        for (int64_t i03 = 0; i03 < ne03; ++i03) {
-            for (int64_t i02 = 0; i02 < ne02; ++i02) {
-                const void * x = (char *)src0->data + i02*nb02 + i03*nb03;
-                float * wplane = (float *)wdata + i02*ne_plane + i03*ne02*ne_plane;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+                float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
 
                 const int min_cols_per_thread = 4096;
-                const int min_rows_per_thread = std::max((int)(min_cols_per_thread / ne00), 1);
-                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01 / min_rows_per_thread)), 1);
+                const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
+                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
 
 #ifdef GGML_USE_OPENMP
                 #pragma omp parallel for num_threads(n_threads)
-                for (int64_t i01 = 0; i01 < ne01; ++i01) {
-                    to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                 }
 #else
                 for (int i = 1; i < n_threads; i++) {
-                    const int64_t start = i * ne01/n_threads;
-                    const int64_t end = (i + 1) * ne01/n_threads;
+                    const int64_t start =       i*ne01/n_threads;
+                    const int64_t end   = (i + 1)*ne01/n_threads;
                     if (start < end) {
                         ctx->tasks.push_back(std::async(std::launch::async, [=]() {
-                            for (int64_t i01 = start; i01 < end; ++i01) {
-                                to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
+                            for (int64_t i01 = start; i01 < end; i01++) {
+                                to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                             }
                         }));
                     }
                 }
                 {
+                    // reuse the current thread for the first task
                     const int64_t start = 0;
-                    const int64_t end = ne01/n_threads;
+                    const int64_t end   = ne01/n_threads;
                     for (int64_t i01 = start; i01 < end; i01++) {
-                        to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
+                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                     }
                 }
 #endif
             }
         }
 
 #ifndef GGML_USE_OPENMP
-        for (auto & task: ctx->tasks) {
+        // wait for all tasks to finish
+        for (auto & task : ctx->tasks) {
             task.get();
         }
         ctx->tasks.clear();
 #endif
     }
 
-#ifdef OPENBLAS_VERSION
+#if defined(OPENBLAS_VERSION)
     openblas_set_num_threads(ctx->n_threads);
 #endif
 
-#ifdef GGML_BLAS_USE_BLIS
+#if defined(GGML_BLAS_USE_BLIS)
     bli_thread_set_num_threads(ctx->n_threads);
 #endif
 
-#ifdef GGML_BLAS_USE_NVPL
+#if defined(GGML_BLAS_USE_NVPL)
     nvpl_blas_set_num_threads(ctx->n_threads);
 #endif
 
-    for (int64_t i13 = 0; i13 < ne13; ++i13) {
-        for (int64_t j = 0; j < ne12; ++j) {
-            const int64_t ids_batch_index = (ne22 > 1 ? i13 : 0);
-            const int32_t * ids_row = (const int32_t *)((char *)src2->data + ids_batch_index*nb22 + j*nb21);
-            float * out_ptr = (float *)((char *)dst->data + i13*nb3 + j*nb2);
-
-            for (int iE = 0; iE < n_used; ++iE) {
-                const int expert_id = ids_row[iE];
-                GGML_ASSERT(expert_id < ne02);
-
-                const float * wmat;
-                if (type == GGML_TYPE_F32) {
-                    wmat = (const float *)((char *)src0->data + expert_id*nb02);
-                } else {
-                    wmat = (const float *)((char *)wdata + expert_id * ne_plane * sizeof(float));
-                }
+    const int n_ids    = ids->ne[0];
+    const int n_tokens = ids->ne[1];
 
-                if (ne03 > 1) {
-                    int64_t w_batch_index = (ne03 == ne13 ? i13 : 0);
-                    wmat = (const float *)((char *)wdata + (w_batch_index * ne02 + expert_id) * ne_plane * sizeof(float));
-                }
+    for (int t = 0; t < n_tokens; ++t) {
+        for (int e = 0; e < n_ids; ++e) {
+            const int32_t expert = *(const int32_t *) ((const char *) ids->data + e*ids->nb[0] + t*ids->nb[1]);
+            GGML_ASSERT(expert >= 0 && expert < ne02);
 
-                const float * inp = (const float *)((char *)src1->data
-                    + ((ne11 == 1 ? 0 : iE) * nb11)
-                    + j * nb12 + i13 * nb13);
-
-                if (iE == 0) {
-                    cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)ne01, (int)ne00,
-                                1.0f, wmat, (int)ne00,
-                                inp, 1,
-                                0.0f,
-                                out_ptr, 1);
-                } else {
-                    cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)ne01, (int)ne00,
-                                1.0f, wmat, (int)ne00,
-                                inp, 1,
-                                1.0f,
-                                out_ptr, 1);
-                }
+            const int e_src1 = e % ne11;
+
+            const float * a = (float *) ((char *) src0->data + expert*nb02);
+            const float * b = (float *) ((char *) src1->data + e_src1*nb11 + t*nb12);
+            float       * d = (float *) ((char *) dst->data + e*nb1 + t*nb2);
+
+            if (type != GGML_TYPE_F32) {
+                a = (float *) wdata + expert*ne_plane;
             }
+
+            cblas_sgemv(CblasRowMajor, CblasNoTrans,
+                        ne01, ne00,
+                        1.0f, a, ne00,
+                        b, 1,
+                        0.0f, d, 1);
         }
     }
 }
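
The rewritten loop treats MUL_MAT_ID as a batch of independent matrix-vector products: for each (token, expert-slot) pair it reads the expert index from ids, points a at that expert's weight plane (src0 directly for F32, otherwise the converted plane in wdata), and issues a single cblas_sgemv with beta = 0.0f, so every pair writes its own output row instead of accumulating across experts as the deleted code did. Below is a minimal standalone sketch of that dispatch pattern; it uses dense row-major F32 buffers and made-up names (n_experts, n_rows, a row-per-pair output layout), not the ggml tensor API, and assumes one src1 column per (slot, token) pair, i.e. no ne11 broadcast.

#include <cblas.h>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_experts = 4; // ne02: expert matrices in src0
    const int n_rows    = 2; // ne01: rows per expert matrix
    const int n_cols    = 3; // ne00: input width
    const int n_tokens  = 2; // ids->ne[1]
    const int n_ids     = 2; // ids->ne[0]: experts used per token

    // weights: n_experts contiguous row-major planes, already F32
    std::vector<float> src0(n_experts*n_rows*n_cols);
    for (size_t i = 0; i < src0.size(); ++i) src0[i] = 0.1f*(float)i;

    // inputs: one column of length n_cols per (slot, token) pair
    std::vector<float> src1(n_ids*n_tokens*n_cols, 1.0f);

    // expert index per (slot, token), laid out [token][slot]
    const int32_t ids[n_ids*n_tokens] = {0, 2, 1, 3};

    std::vector<float> dst(n_ids*n_tokens*n_rows, 0.0f);

    for (int t = 0; t < n_tokens; ++t) {
        for (int e = 0; e < n_ids; ++e) {
            const int32_t expert = ids[t*n_ids + e];
            const float * a = src0.data() + expert*n_rows*n_cols; // selected expert plane
            const float * b = src1.data() + (t*n_ids + e)*n_cols; // matching input column
            float       * d = dst.data()  + (t*n_ids + e)*n_rows; // own output row

            // d = 1.0*A*b + 0.0*d: beta is 0 because each (token, slot) pair
            // owns its output row; nothing is accumulated across experts
            cblas_sgemv(CblasRowMajor, CblasNoTrans, n_rows, n_cols,
                        1.0f, a, n_cols,
                        b, 1,
                        0.0f, d, 1);
        }
    }

    printf("dst[0] = %.3f\n", dst[0]); // token 0, slot 0 (expert 0), row 0
    return 0;
}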

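The "convert src0 to float" block keeps the same staging scheme as ggml_backend_blas_mul_mat: each expert plane is dequantized once into a contiguous F32 region of the work buffer, which is why the GEMV loop can later rebase a to wdata + expert*ne_plane. Here is a small sketch of that layout, with an invented row_to_float standing in for the ggml_to_float_t obtained from ggml_get_type_traits (an illustrative int8-with-scale format, not a real ggml quantization type):

#include <cstdint>
#include <vector>

// stand-in for ggml's to_float: dequantize one row of n weights
static void row_to_float(const int8_t * q, float scale, float * out, int64_t n) {
    for (int64_t i = 0; i < n; ++i) {
        out[i] = scale*q[i];
    }
}

int main() {
    const int64_t ne00 = 4, ne01 = 2, ne02 = 3; // cols, rows, experts
    const int64_t ne_plane = ne01*ne00;

    std::vector<int8_t> quant(ne02*ne_plane, 2); // quantized src0
    std::vector<float>  wdata(ne02*ne_plane);    // F32 staging buffer

    // convert every expert plane up front, exactly once
    for (int64_t i02 = 0; i02 < ne02; ++i02) {
        float * wplane = wdata.data() + i02*ne_plane;
        for (int64_t i01 = 0; i01 < ne01; ++i01) {
            row_to_float(quant.data() + i02*ne_plane + i01*ne00, 0.5f,
                         wplane + i01*ne00, ne00);
        }
    }

    // the GEMV loop addresses expert k's F32 weights at wdata + k*ne_plane
    const float * a = wdata.data() + 1*ne_plane;
    return a[0] == 1.0f ? 0 : 1; // 2 * 0.5 = 1.0
}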