Skip to content

Commit 65910d3

Browse files
committed
use mad instead of fma
1 parent d3f049b commit 65910d3

File tree

3 files changed

+15
-15
lines changed

3 files changed

+15
-15
lines changed

ggml/src/ggml-opencl/kernels/flash_attn_f16.cl

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -125,8 +125,8 @@ __kernel void flash_attn_f16(
125125
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
126126
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
127127
for (int k = 0; k < DK_VEC; k++) {
128-
dot_acc0 = fma(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
129-
dot_acc1 = fma(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
128+
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
129+
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
130130
}
131131
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
132132
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
@@ -251,7 +251,7 @@ __kernel void flash_attn_f16_q1(
251251
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
252252
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
253253
for (int k = 0; k < DK_VEC; k++) {
254-
dot_acc = fma(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
254+
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
255255
}
256256
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
257257
if (mask_base != NULL) {
@@ -284,7 +284,7 @@ __kernel void flash_attn_f16_q1(
284284
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
285285
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
286286
for (int k = 0; k < DK_VEC; k++) {
287-
dot_acc = fma(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
287+
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
288288
}
289289
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
290290
if (mask_base != NULL) {
@@ -297,7 +297,7 @@ __kernel void flash_attn_f16_q1(
297297
const ACC_TYPE p = exp(score - m_final);
298298
l_i += p;
299299
for (int i = 0; i < DV_VEC; i++) {
300-
o_acc[i] = fma(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
300+
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
301301
}
302302
}
303303

ggml/src/ggml-opencl/kernels/flash_attn_f32.cl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ __kernel void flash_attn_f32(
125125
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
126126
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
127127
for (int k = 0; k < DK_VEC; k++) {
128-
dot_acc0 = fma(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
129-
dot_acc1 = fma(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
128+
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
129+
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
130130
}
131131
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
132132
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
@@ -251,7 +251,7 @@ __kernel void flash_attn_f32_q1(
251251
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
252252
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
253253
for (int k = 0; k < DK_VEC; k++) {
254-
dot_acc = fma(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
254+
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
255255
}
256256
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
257257
if (mask_base != NULL) {
@@ -284,7 +284,7 @@ __kernel void flash_attn_f32_q1(
284284
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
285285
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
286286
for (int k = 0; k < DK_VEC; k++) {
287-
dot_acc = fma(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
287+
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
288288
}
289289
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
290290
if (mask_base != NULL) {
@@ -297,7 +297,7 @@ __kernel void flash_attn_f32_q1(
297297
const ACC_TYPE p = exp(score - m_final);
298298
l_i += p;
299299
for (int i = 0; i < DV_VEC; i++) {
300-
o_acc[i] = fma(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
300+
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
301301
}
302302
}
303303

ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ __kernel void flash_attn_f32_f16(
128128
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
129129
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
130130
for (int k = 0; k < DK_VEC; k++) {
131-
dot_acc0 = fma(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
132-
dot_acc1 = fma(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
131+
dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
132+
dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
133133
}
134134
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
135135
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
@@ -254,7 +254,7 @@ __kernel void flash_attn_f32_f16_q1(
254254
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
255255
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
256256
for (int k = 0; k < DK_VEC; k++) {
257-
dot_acc = fma(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
257+
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
258258
}
259259
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
260260
if (mask_base != NULL) {
@@ -287,7 +287,7 @@ __kernel void flash_attn_f32_f16_q1(
287287
const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
288288
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
289289
for (int k = 0; k < DK_VEC; k++) {
290-
dot_acc = fma(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
290+
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
291291
}
292292
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
293293
if (mask_base != NULL) {
@@ -300,7 +300,7 @@ __kernel void flash_attn_f32_f16_q1(
300300
const ACC_TYPE p = exp(score - m_final);
301301
l_i += p;
302302
for (int i = 0; i < DV_VEC; i++) {
303-
o_acc[i] = fma(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
303+
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
304304
}
305305
}
306306

0 commit comments

Comments
 (0)