Skip to content

Commit ebbad77

Browse files
committed
add x param to ggml_vec_mad1_f32
1 parent cd1703a commit ebbad77

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4670,17 +4670,17 @@ static void ggml_compute_forward_scale_f32(
46704670
for (int i1 = ir0; i1 < ir1; i1++) {
46714671
if (dst->data != src0->data) {
46724672
// src0 is same shape as dst => same indices
4673+
// TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
46734674
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
46744675
}
46754676
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
46764677
}
46774678
} else {
46784679
for (int i1 = ir0; i1 < ir1; i1++) {
4679-
if (dst->data != src0->data) {
4680-
// src0 is same shape as dst => same indices
4681-
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4682-
}
4683-
ggml_vec_mad1_f32(nc, (float *) ((char *) dst->data + i1*nb1), s, b);
4680+
ggml_vec_mad1_f32(nc,
4681+
(float *) ((char *) dst->data + i1*nb1),
4682+
(float *) ((char *) src0->data + i1*nb1),
4683+
s, b);
46844684
}
46854685
}
46864686
}

ggml/src/ggml-cpu/vec.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -351,14 +351,14 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
351351
#endif
352352
}
353353

354-
inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, const float b) {
354+
inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
355355
#if defined(GGML_USE_ACCELERATE)
356-
vDSP_vsmsa(y, 1, &s, &b, y, 1, n);
356+
vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
357357
#elif defined(GGML_SIMD)
358358
#if defined(__ARM_FEATURE_SVE)
359359
// scalar fallback; TODO: write SVE code
360360
for (int i = 0; i < n; ++i) {
361-
y[i] = y[i]*s + b;
361+
y[i] = x[i]*s + b;
362362
}
363363
#else
364364
const int np = (n & ~(GGML_F32_STEP - 1));
@@ -370,7 +370,7 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, cons
370370

371371
for (int i = 0; i < np; i += GGML_F32_STEP) {
372372
for (int j = 0; j < GGML_F32_ARR; j++) {
373-
ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
373+
ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
374374
ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);
375375

376376
GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
@@ -379,13 +379,13 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, cons
379379

380380
// leftovers
381381
for (int i = np; i < n; ++i) {
382-
y[i] = y[i]*s + b;
382+
y[i] = x[i]*s + b;
383383
}
384384
#endif
385385
#else
386386
// scalar
387387
for (int i = 0; i < n; ++i) {
388-
y[i] = y[i]*s + b;
388+
y[i] = x[i]*s + b;
389389
}
390390
#endif
391391
}

0 commit comments

Comments
 (0)