Skip to content

Commit a652816

Browse files
committed
Fix FlashAttention: the GEMM kernel requires 64-bit-aligned matrices
1 parent c52d7ef commit a652816

File tree

4 files changed

+24
-22
lines changed

4 files changed

+24
-22
lines changed

sw/kernels/blas/gemm/src/gemm_fp16.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ void gemm_fp16_opt_ex(uint32_t setup_ssr, uint32_t partition_banks,
306306
#ifdef SNRT_SUPPORTS_FREP
307307
__fp16* A = (__fp16*)A_p;
308308
__fp16* B = (__fp16*)B_p;
309-
__fp16* C = (__fp16*)C_p;
309+
__fp16* C = (__fp16*)C_p; // Should be double-aligned (see fsd below)
310310

311311
// Unrolling factor of most inner loop.
312312
// Should be at least as high as the FMA delay

sw/kernels/blas/gemm/src/gemm_fp32.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ void gemm_fp32_opt(uint32_t setup_ssr, uint32_t partition_banks,
214214
// cast void pointers to float pointers
215215
float* A = (float*)A_p;
216216
float* B = (float*)B_p;
217-
float* C = (float*)C_p;
217+
float* C = (float*)C_p; // Should be double-aligned (see fsd below)
218218
// Unrolling factor of most inner loop.
219219
// Should be at least as high as the FMA delay
220220
// for maximum utilization

sw/kernels/dnn/flashattention_2/src/flashattention_2_fp16.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,28 @@ static inline void flashattention_2_fp16(flashattention_2_layer_t layer) {
5656
uint32_t shifted_exp_size = B_r * sizeof(float);
5757

5858
// allocate memory in TCDM
59+
// align to size of double since this is required for some GEMM arrays
5960
__fp16 *Q_fa =
60-
(__fp16 *)snrt_l1_alloc_cluster_local(q_fa_size, alignof(__fp16));
61+
(__fp16 *)snrt_l1_alloc_cluster_local(q_fa_size, alignof(double));
6162
__fp16 *K_fa =
62-
(__fp16 *)snrt_l1_alloc_cluster_local(k_fa_size, alignof(__fp16));
63+
(__fp16 *)snrt_l1_alloc_cluster_local(k_fa_size, alignof(double));
6364
__fp16 *V_fa =
64-
(__fp16 *)snrt_l1_alloc_cluster_local(v_fa_size, alignof(__fp16));
65+
(__fp16 *)snrt_l1_alloc_cluster_local(v_fa_size, alignof(double));
6566
__fp16 *S_fa =
66-
(__fp16 *)snrt_l1_alloc_cluster_local(s_fa_size, alignof(__fp16));
67+
(__fp16 *)snrt_l1_alloc_cluster_local(s_fa_size, alignof(double));
6768
__fp16 *P_fa =
68-
(__fp16 *)snrt_l1_alloc_cluster_local(p_fa_size, alignof(__fp16));
69+
(__fp16 *)snrt_l1_alloc_cluster_local(p_fa_size, alignof(double));
6970
__fp16 *O_fa =
70-
(__fp16 *)snrt_l1_alloc_cluster_local(o_fa_size, alignof(__fp16));
71-
float *m_i = (float *)snrt_l1_alloc_cluster_local(m_i_size, alignof(float));
71+
(__fp16 *)snrt_l1_alloc_cluster_local(o_fa_size, alignof(double));
72+
float *m_i = (float *)snrt_l1_alloc_cluster_local(m_i_size, alignof(double));
7273
float *m_i_prev =
73-
(float *)snrt_l1_alloc_cluster_local(m_i_size, alignof(float));
74-
float *l_i = (float *)snrt_l1_alloc_cluster_local(l_i_size, alignof(float));
74+
(float *)snrt_l1_alloc_cluster_local(m_i_size, alignof(double));
75+
float *l_i = (float *)snrt_l1_alloc_cluster_local(l_i_size, alignof(double));
7576

7677
// Allocate space for V^t
7778
__fp16 *V_t;
7879
if (!baseline) {
79-
V_t = (__fp16 *)snrt_l1_alloc_cluster_local(v_fa_size, alignof(__fp16));
80+
V_t = (__fp16 *)snrt_l1_alloc_cluster_local(v_fa_size, alignof(double));
8081
}
8182

8283
float shifted_exp;

sw/kernels/dnn/flashattention_2/src/flashattention_2_fp32.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,27 +56,28 @@ static inline void flashattention_2_fp32(flashattention_2_layer_t layer) {
5656
uint32_t shifted_exp_size = B_r * sizeof(float);
5757

5858
// allocate memory in TCDM
59+
// align to size of double since this is required for some GEMM arrays
5960
float *Q_fa =
60-
(float *)snrt_l1_alloc_cluster_local(q_fa_size, alignof(float));
61+
(float *)snrt_l1_alloc_cluster_local(q_fa_size, alignof(double));
6162
float *K_fa =
62-
(float *)snrt_l1_alloc_cluster_local(k_fa_size, alignof(float));
63+
(float *)snrt_l1_alloc_cluster_local(k_fa_size, alignof(double));
6364
float *V_fa =
64-
(float *)snrt_l1_alloc_cluster_local(v_fa_size, alignof(float));
65+
(float *)snrt_l1_alloc_cluster_local(v_fa_size, alignof(double));
6566
float *S_fa =
66-
(float *)snrt_l1_alloc_cluster_local(s_fa_size, alignof(float));
67+
(float *)snrt_l1_alloc_cluster_local(s_fa_size, alignof(double));
6768
float *P_fa =
68-
(float *)snrt_l1_alloc_cluster_local(p_fa_size, alignof(float));
69+
(float *)snrt_l1_alloc_cluster_local(p_fa_size, alignof(double));
6970
float *O_fa =
70-
(float *)snrt_l1_alloc_cluster_local(o_fa_size, alignof(float));
71-
float *m_i = (float *)snrt_l1_alloc_cluster_local(m_i_size, alignof(float));
71+
(float *)snrt_l1_alloc_cluster_local(o_fa_size, alignof(double));
72+
float *m_i = (float *)snrt_l1_alloc_cluster_local(m_i_size, alignof(double));
7273
float *m_i_prev =
73-
(float *)snrt_l1_alloc_cluster_local(m_i_size, alignof(float));
74-
float *l_i = (float *)snrt_l1_alloc_cluster_local(l_i_size, alignof(float));
74+
(float *)snrt_l1_alloc_cluster_local(m_i_size, alignof(double));
75+
float *l_i = (float *)snrt_l1_alloc_cluster_local(l_i_size, alignof(double));
7576

7677
// allocate space for V^t when using optimized kernels
7778
float *V_t;
7879
if (!baseline) {
79-
V_t = (float *)snrt_l1_alloc_cluster_local(v_fa_size, alignof(float));
80+
V_t = (float *)snrt_l1_alloc_cluster_local(v_fa_size, alignof(double));
8081
}
8182

8283
float shifted_exp;

0 commit comments

Comments
 (0)