Skip to content

Commit 1926e07

Browse files
committed
ggml-blas: code clean up
Signed-off-by: Aaron Teo <[email protected]>
1 parent 19c8ec9 commit 1926e07

File tree

1 file changed

+31
-32
lines changed

1 file changed

+31
-32
lines changed

ggml/src/ggml-blas/ggml-blas.cpp

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,16 @@ struct ggml_backend_blas_context {
2727
#endif
2828
};
2929

30-
static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
31-
const struct ggml_tensor * src0 = dst->src[0];
32-
const struct ggml_tensor * src1 = dst->src[1];
30+
static void ggml_backend_blas_mul_mat(
31+
ggml_backend_blas_context * ctx,
32+
ggml_tensor * dst) {
33+
34+
const ggml_tensor * src0 = dst->src[0];
35+
const ggml_tensor * src1 = dst->src[1];
3336

3437
GGML_TENSOR_BINARY_OP_LOCALS
3538

36-
const enum ggml_type type = src0->type;
39+
const ggml_type type = src0->type;
3740

3841
GGML_ASSERT(ne0 == ne01);
3942
GGML_ASSERT(ne1 == ne11);
@@ -70,8 +73,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
7073

7174
for (int64_t i03 = 0; i03 < ne03; i03++) {
7275
for (int64_t i02 = 0; i02 < ne02; i02++) {
73-
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
74-
float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
76+
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
77+
float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
7578

7679
const int min_cols_per_thread = 4096;
7780
const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
@@ -84,8 +87,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
8487
}
8588
#else
8689
for (int i = 1; i < n_threads; i++) {
87-
const int64_t start = i*ne01/n_threads;
88-
const int64_t end = (i + 1)*ne01/n_threads;
90+
const int64_t start = (i + 0) * ne01/n_threads;
91+
const int64_t end = (i + 1) * ne01/n_threads;
8992
if (start < end) {
9093
ctx->tasks.push_back(std::async(std::launch::async, [=]() {
9194
for (int64_t i01 = start; i01 < end; i01++) {
@@ -149,14 +152,17 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
149152
}
150153
}
151154

152-
static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) {
153-
const struct ggml_tensor * src0 = dst->src[0];
154-
const struct ggml_tensor * src1 = dst->src[1];
155-
const struct ggml_tensor * ids = dst->src[2];
155+
static void ggml_backend_blas_mul_mat_id(
156+
ggml_backend_blas_context * ctx,
157+
ggml_tensor * dst) {
158+
159+
const ggml_tensor * src0 = dst->src[0];
160+
const ggml_tensor * src1 = dst->src[1];
161+
const ggml_tensor * ids = dst->src[2];
156162

157163
GGML_TENSOR_BINARY_OP_LOCALS
158164

159-
const enum ggml_type type = src0->type;
165+
const ggml_type type = src0->type;
160166

161167
GGML_ASSERT(nb00 == ggml_type_size(type));
162168
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
@@ -173,15 +179,10 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_t
173179
GGML_ASSERT(src1->type == GGML_TYPE_F32);
174180
GGML_ASSERT(ids->type == GGML_TYPE_I32);
175181

176-
// broadcast factors
177-
const int64_t r2 = ne12/ne02;
178-
const int64_t r3 = ne13/ne03;
179-
180-
GGML_UNUSED(r2);
181-
GGML_UNUSED(r3);
182-
183182
const int64_t ne_plane = ne01*ne00;
184-
const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
183+
const size_t desired_wsize = type == GGML_TYPE_F32
184+
? 0
185+
: ne03*ne02*ne_plane*sizeof(float);
185186

186187
if (ctx->work_size < desired_wsize) {
187188
ctx->work_data.reset(new char[desired_wsize]);
@@ -196,8 +197,8 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_t
196197

197198
for (int64_t i03 = 0; i03 < ne03; i03++) {
198199
for (int64_t i02 = 0; i02 < ne02; i02++) {
199-
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
200-
float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
200+
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
201+
float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
201202

202203
const int min_cols_per_thread = 4096;
203204
const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
@@ -210,7 +211,7 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_t
210211
}
211212
#else
212213
for (int i = 1; i < n_threads; i++) {
213-
const int64_t start = i*ne01/n_threads;
214+
const int64_t start = (i + 0)*ne01/n_threads;
214215
const int64_t end = (i + 1)*ne01/n_threads;
215216
if (start < end) {
216217
ctx->tasks.push_back(std::async(std::launch::async, [=]() {
@@ -555,15 +556,13 @@ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const s
555556

556557
case GGML_OP_MUL_MAT_ID:
557558
{
558-
const struct ggml_tensor * src0 = op->src[0];
559-
const struct ggml_tensor * src1 = op->src[1];
560-
const struct ggml_tensor * src2 = op->src[2];
561-
562-
// GGML_LOG_INFO("%s: op=GGML_OP_MUL_MAT_ID src0_type=%s src1_type=%s src2_type=%s ne0=%lld ne1=%lld ne2=%lld ne3=%lld\n",
563-
// __func__, ggml_type_name(src0->type), ggml_type_name(src1->type), ggml_type_name(src2->type),
564-
// op->ne[0], op->ne[1], op->ne[2], op->ne[3]);
559+
const ggml_tensor * src0 = op->src[0];
560+
const ggml_tensor * src1 = op->src[1];
565561

566-
return src2->type == GGML_TYPE_I32;
562+
return ggml_is_contiguous(src0) &&
563+
ggml_is_contiguous(src1) &&
564+
src1->type == GGML_TYPE_F32 &&
565+
(src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
567566
}
568567

569568
case GGML_OP_OUT_PROD:

0 commit comments

Comments (0)