diff --git a/src/ggml-cpu/ops.cpp b/src/ggml-cpu/ops.cpp index 3c2adb2172..e297c52b33 100644 --- a/src/ggml-cpu/ops.cpp +++ b/src/ggml-cpu/ops.cpp @@ -3292,6 +3292,7 @@ static void ggml_compute_forward_rms_norm_back_f32( //const float rms = sqrtf(mean_eps); const float rrms = 1.0f / sqrtf(mean_eps); //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3) + const float scale_x = rrms * (-sum_xdz)/sum_eps; { // z = rms_norm(x) @@ -3390,11 +3391,15 @@ static void ggml_compute_forward_rms_norm_back_f32( float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps) - ggml_vec_cpy_f32 (ne00, dx, x); + // ggml_vec_cpy_f32 (ne00, dx, x); // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); - ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); - ggml_vec_acc_f32 (ne00, dx, dz); - ggml_vec_scale_f32(ne00, dx, rrms); + // ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); + // ggml_vec_acc_f32 (ne00, dx, dz); + // ggml_vec_scale_f32(ne00, dx, rrms); + + for (int64_t i00 = 0; i00 < ne00; i00++) { + dx[i00] = rrms * dz[i00] + scale_x * x[i00]; + } } } }