Skip to content

Commit 1b06404

Browse files
committed
use inline function
1 parent 65910d3 commit 1b06404

File tree

3 files changed

+39
-54
lines changed

3 files changed

+39
-54
lines changed

ggml/src/ggml-opencl/kernels/flash_attn_f16.cl

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,17 @@
1212
#define WG_SIZE (BLOCK_M)
1313
#define Q1_WG_SIZE 64
1414

15+
inline float get_alibi_slope(
16+
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
17+
) {
18+
if (max_bias <= 0.0f) {
19+
return 1.0f;
20+
}
21+
const float base = h < n_head_log2 ? m0 : m1;
22+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
23+
24+
return pow(base, exph);
25+
}
1526
__kernel void flash_attn_f16(
1627
const global void * q_void, ulong q_offset,
1728
const global void * k_void, ulong k_offset,
@@ -80,15 +91,7 @@ __kernel void flash_attn_f16(
8091
ACC_TYPE m_i = -INFINITY;
8192
ACC_TYPE l_i = 0.0f;
8293

83-
float slope = 1.0f;
84-
if (max_bias > 0.0f) {
85-
int h = head_idx;
86-
if (h < n_head_log2) {
87-
slope = pow(m0, (float)(h + 1));
88-
} else {
89-
slope = pow(m1, (float)(2 * (h - n_head_log2) + 1));
90-
}
91-
}
94+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
9295

9396
__local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
9497
__local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
@@ -235,15 +238,7 @@ __kernel void flash_attn_f16_q1(
235238
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
236239
}
237240

238-
float slope = 1.0f;
239-
if (max_bias > 0.0f) {
240-
int h = head_idx;
241-
if (h < n_head_log2) {
242-
slope = pow(m0, (float)(h + 1));
243-
} else {
244-
slope = pow(m1, (float)(2 * (h - n_head_log2) + 1));
245-
}
246-
}
241+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
247242

248243
ACC_TYPE m_i = -INFINITY;
249244
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {

ggml/src/ggml-opencl/kernels/flash_attn_f32.cl

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,17 @@
1212
#define WG_SIZE (BLOCK_M)
1313
#define Q1_WG_SIZE 64
1414

15+
inline float get_alibi_slope(
16+
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
17+
) {
18+
if (max_bias <= 0.0f) {
19+
return 1.0f;
20+
}
21+
const float base = h < n_head_log2 ? m0 : m1;
22+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
23+
24+
return pow(base, exph);
25+
}
1526
__kernel void flash_attn_f32(
1627
const global void * q_void, ulong q_offset,
1728
const global void * k_void, ulong k_offset,
@@ -80,15 +91,7 @@ __kernel void flash_attn_f32(
8091
ACC_TYPE m_i = -INFINITY;
8192
ACC_TYPE l_i = 0.0f;
8293

83-
float slope = 1.0f;
84-
if (max_bias > 0.0f) {
85-
int h = head_idx;
86-
if (h < n_head_log2) {
87-
slope = pow(m0, (float)(h + 1));
88-
} else {
89-
slope = pow(m1, (float)(2 * (h - n_head_log2) + 1));
90-
}
91-
}
94+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
9295

9396
__local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
9497
__local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
@@ -235,15 +238,7 @@ __kernel void flash_attn_f32_q1(
235238
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
236239
}
237240

238-
float slope = 1.0f;
239-
if (max_bias > 0.0f) {
240-
int h = head_idx;
241-
if (h < n_head_log2) {
242-
slope = pow(m0, (float)(h + 1));
243-
} else {
244-
slope = pow(m1, (float)(2 * (h - n_head_log2) + 1));
245-
}
246-
}
241+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
247242

248243
ACC_TYPE m_i = -INFINITY;
249244
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {

ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@
1515
#define WG_SIZE (BLOCK_M)
1616
#define Q1_WG_SIZE 64
1717

18+
inline float get_alibi_slope(
19+
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
20+
) {
21+
if (max_bias <= 0.0f) {
22+
return 1.0f;
23+
}
24+
const float base = h < n_head_log2 ? m0 : m1;
25+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
26+
27+
return pow(base, exph);
28+
}
1829
__kernel void flash_attn_f32_f16(
1930
const global void * q_void, ulong q_offset,
2031
const global void * k_void, ulong k_offset,
@@ -83,15 +94,7 @@ __kernel void flash_attn_f32_f16(
8394
ACC_TYPE m_i = -INFINITY;
8495
ACC_TYPE l_i = 0.0f;
8596

86-
float slope = 1.0f;
87-
if (max_bias > 0.0f) {
88-
int h = head_idx;
89-
if (h < n_head_log2) {
90-
slope = pow(m0, (float)(h + 1));
91-
} else {
92-
slope = pow(m1, (float)(2 * (h - n_head_log2) + 1));
93-
}
94-
}
97+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
9598

9699
__local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
97100
__local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
@@ -238,15 +241,7 @@ __kernel void flash_attn_f32_f16_q1(
238241
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
239242
}
240243

241-
float slope = 1.0f;
242-
if (max_bias > 0.0f) {
243-
int h = head_idx;
244-
if (h < n_head_log2) {
245-
slope = pow(m0, (float)(h + 1));
246-
} else {
247-
slope = pow(m1, (float)(2 * (h - n_head_log2) + 1));
248-
}
249-
}
244+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
250245

251246
ACC_TYPE m_i = -INFINITY;
252247
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {

0 commit comments

Comments
 (0)