
Commit f682374

ggml-blas: initial mmid impl
Signed-off-by: Aaron Teo <[email protected]>
1 parent e4ae383 commit f682374
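GGML_OP_MUL_MAT_ID is the expert-routed matrix multiply used by mixture-of-experts models: src0 stacks one weight matrix per expert, src1 holds the input activations, and src2 is an i32 tensor of routing ids selecting which experts apply to each token. As a reading aid for the diff below, here is a minimal scalar sketch of what the new kernel computes for a single (token, batch) slot, assuming f32 weights and a broadcast input column (the ne11 == 1 case); this is an illustration, not code from the commit:

#include <cstdint>

// Hypothetical scalar reference for the MUL_MAT_ID semantics implemented below.
// W:   ne02 expert matrices, each ne01 x ne00, row-major, contiguous
// x:   one input column of ne00 floats (broadcast across selected experts)
// ids: n_used selected expert indices for this token
// out: ne01 floats; the selected experts' outputs are summed
static void mul_mat_id_ref(const float * W, const float * x,
                           const int32_t * ids, float * out,
                           int64_t ne00, int64_t ne01, int64_t n_used) {
    for (int64_t e = 0; e < n_used; ++e) {
        const float * w = W + (int64_t) ids[e]*ne01*ne00; // pick the e-th routed expert
        for (int64_t r = 0; r < ne01; ++r) {
            float acc = (e == 0) ? 0.0f : out[r];         // first expert writes, the rest accumulate
            for (int64_t c = 0; c < ne00; ++c) {
                acc += w[r*ne00 + c]*x[c];
            }
            out[r] = acc;
        }
    }
}

Note that the selected experts' outputs are summed into a single output row, which is what the beta trick in the sgemv calls below relies on.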

1 file changed: +154 -0 lines changed

ggml/src/ggml-blas/ggml-blas.cpp

Lines changed: 154 additions & 0 deletions
@@ -149,6 +149,143 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
    }
}

static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // weights
    const ggml_tensor * src1 = dst->src[1]; // inputs
    const ggml_tensor * src2 = dst->src[2]; // ids

    GGML_TENSOR_TERNARY_OP_LOCALS

    const ggml_type type = src0->type;

    GGML_ASSERT(ne10 == ne00);
    GGML_ASSERT(ne21 == ne12);
    GGML_ASSERT(ne22 == 1 || ne22 == ne13);
    GGML_ASSERT(src2->type == GGML_TYPE_I32);

    GGML_ASSERT(nb00 == ggml_type_size(type));
    GGML_ASSERT(nb10 == ggml_type_size(src1->type));

    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb0 <= nb1 && nb1 <= nb2 && nb2 <= nb3);

    // number of experts applied per token
    const int64_t n_used = (int64_t)ne20;
    GGML_ASSERT(n_used <= ne02);

    // non-f32 weights are converted to f32 into the work buffer
    const int64_t ne_plane = ne01 * ne00;
    const size_t desired_wsize = (type == GGML_TYPE_F32) ? 0 : ne03 * ne02 * ne_plane * sizeof(float);
    if (ctx->work_size < desired_wsize) {
        ctx->work_data.reset(new char[desired_wsize]);
        ctx->work_size = desired_wsize;
    }
    void * wdata = ctx->work_data.get();

    if (type != GGML_TYPE_F32) {
        const auto * type_traits = ggml_get_type_traits(type);
        ggml_to_float_t to_float = type_traits->to_float;

        for (int64_t i03 = 0; i03 < ne03; ++i03) {
            for (int64_t i02 = 0; i02 < ne02; ++i02) {
                const void * x = (char *)src0->data + i02*nb02 + i03*nb03;
                float * wplane = (float *)wdata + i02*ne_plane + i03*ne02*ne_plane;

                const int min_cols_per_thread = 4096;
                const int min_rows_per_thread = std::max((int)(min_cols_per_thread / ne00), 1);
                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01 / min_rows_per_thread)), 1);

#ifdef GGML_USE_OPENMP
                #pragma omp parallel for num_threads(n_threads)
                for (int64_t i01 = 0; i01 < ne01; ++i01) {
                    to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
                }
#else
                for (int i = 1; i < n_threads; i++) {
                    const int64_t start = i * ne01/n_threads;
                    const int64_t end = (i + 1) * ne01/n_threads;
                    if (start < end) {
                        ctx->tasks.push_back(std::async(std::launch::async, [=]() {
                            for (int64_t i01 = start; i01 < end; ++i01) {
                                to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
                            }
                        }));
                    }
                }
                {
                    // use the calling thread for the first chunk
                    const int64_t start = 0;
                    const int64_t end = ne01/n_threads;
                    for (int64_t i01 = start; i01 < end; i01++) {
                        to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
                    }
                }
#endif
            }
        }

#ifndef GGML_USE_OPENMP
        // wait for the conversion tasks to finish
        for (auto & task: ctx->tasks) {
            task.get();
        }
        ctx->tasks.clear();
#endif
    }

#ifdef OPENBLAS_VERSION
    openblas_set_num_threads(ctx->n_threads);
#endif

#ifdef GGML_BLAS_USE_BLIS
    bli_thread_set_num_threads(ctx->n_threads);
#endif

#ifdef GGML_BLAS_USE_NVPL
    nvpl_blas_set_num_threads(ctx->n_threads);
#endif

    for (int64_t i13 = 0; i13 < ne13; ++i13) {
        for (int64_t j = 0; j < ne12; ++j) {
            const int64_t ids_batch_index = (ne22 > 1 ? i13 : 0);
            const int32_t * ids_row = (const int32_t *)((char *)src2->data + ids_batch_index*nb22 + j*nb21);
            float * out_ptr = (float *)((char *)dst->data + i13*nb3 + j*nb2);

            for (int iE = 0; iE < n_used; ++iE) {
                const int expert_id = ids_row[iE];
                GGML_ASSERT(expert_id < ne02);

                const float * wmat;
                if (type == GGML_TYPE_F32) {
                    wmat = (const float *)((char *)src0->data + expert_id*nb02);
                } else {
                    wmat = (const float *)((char *)wdata + expert_id * ne_plane * sizeof(float));
                }

                if (ne03 > 1) {
                    int64_t w_batch_index = (ne03 == ne13 ? i13 : 0);
                    wmat = (const float *)((char *)wdata + (w_batch_index * ne02 + expert_id) * ne_plane * sizeof(float));
                }

                const float * inp = (const float *)((char *)src1->data
                    + ((ne11 == 1 ? 0 : iE) * nb11)
                    + j * nb12 + i13 * nb13);

                // first selected expert overwrites dst (beta = 0.0f),
                // the remaining ones accumulate into it (beta = 1.0f)
                if (iE == 0) {
                    cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)ne01, (int)ne00,
                                1.0f, wmat, (int)ne00,
                                inp, 1,
                                0.0f,
                                out_ptr, 1);
                } else {
                    cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)ne01, (int)ne00,
                                1.0f, wmat, (int)ne00,
                                inp, 1,
                                1.0f,
                                out_ptr, 1);
                }
            }
        }
    }
}

static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];
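The per-expert GEMV calls above lean on the BLAS beta parameter to fuse the expert reduction: the first selected expert is issued with beta = 0.0f, which overwrites dst, and every subsequent expert with beta = 1.0f, which accumulates into it, so no separate zeroing or summation pass is needed. A standalone sketch of that pattern (assumes a CBLAS implementation is linked; shapes are illustrative):

#include <cblas.h>
#include <cstddef>

// y = sum over e of A_e * x_e, fused into repeated sgemv calls via beta:
// beta = 0.0f overwrites y on the first call, beta = 1.0f accumulates afterwards.
void gemv_accumulate(const float * A, const float * X, float * y,
                     int m, int n, int n_mats) {
    for (int e = 0; e < n_mats; ++e) {
        cblas_sgemv(CblasRowMajor, CblasNoTrans, m, n,
                    1.0f, A + (size_t) e*m*n, n, // e-th m x n matrix
                    X + (size_t) e*n, 1,         // e-th input vector
                    e == 0 ? 0.0f : 1.0f,        // overwrite, then accumulate
                    y, 1);
    }
}

Fusing the reduction this way avoids both an explicit zeroing pass over dst and a per-expert temporary buffer.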
@@ -235,6 +372,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
                ggml_backend_blas_mul_mat(ctx, node);
                break;

            case GGML_OP_MUL_MAT_ID:
                ggml_backend_blas_mul_mat_id(ctx, node);
                break;

            case GGML_OP_OUT_PROD:
                ggml_backend_blas_out_prod(ctx, node);
                break;
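Graphs reach this new case through the public ggml_mul_mat_id() builder. A hedged sketch of constructing such a node follows; the tensor sizes are illustrative and may need adjusting to ggml's shape rules for this op:

#include "ggml.h"

// Build a small MUL_MAT_ID node: 4 experts of shape 32x16, one token,
// 2 experts selected per token. Illustrative sizes only.
static struct ggml_tensor * build_mmid(struct ggml_context * ctx) {
    struct ggml_tensor * as  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 32, 16, 4); // expert weights
    struct ggml_tensor * b   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 32, 2, 1);  // inputs
    struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, 1);      // routing ids
    return ggml_mul_mat_id(ctx, as, b, ids); // dispatched to the new case above
}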
@@ -418,6 +559,19 @@ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
        }

        case GGML_OP_MUL_MAT_ID:
        {
            const struct ggml_tensor * src0 = op->src[0];
            const struct ggml_tensor * src1 = op->src[1];
            const struct ggml_tensor * src2 = op->src[2];

            // GGML_LOG_INFO("%s: op=GGML_OP_MUL_MAT_ID src0_type=%s src1_type=%s src2_type=%s ne0=%lld ne1=%lld ne2=%lld ne3=%lld\n",
            //     __func__, ggml_type_name(src0->type), ggml_type_name(src1->type), ggml_type_name(src2->type),
            //     op->ne[0], op->ne[1], op->ne[2], op->ne[3]);

            return src2->type == GGML_TYPE_I32;
        }

        case GGML_OP_OUT_PROD:
            return op->src[0]->type == GGML_TYPE_F32 &&
                   op->src[1]->type == GGML_TYPE_F32 &&
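To exercise the path end to end, the node can be allocated and computed on the BLAS backend. A driver sketch against the public ggml-backend API, reusing the hypothetical build_mmid() from the sketch above (error handling omitted):

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-blas.h"

int main(void) {
    ggml_backend_t backend = ggml_backend_blas_init();

    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data lives in a backend buffer
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * node = build_mmid(ctx); // from the sketch above

    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
    // ... fill the weight/input/id tensors with ggml_backend_tensor_set(), then:

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, node);
    ggml_backend_graph_compute(backend, gf); // hits the GGML_OP_MUL_MAT_ID case above

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}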
