@@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[3];
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t HEADS = dst->src[1]->ne[1];
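+    // v (src[1]) is laid out as [head_size, HEADS, n_tokens], so T and HEADS now come from ne[2]/ne[1]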
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;
 
@@ -12000,6 +12000,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }
 
+// ggml_compute_forward_gla
+
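+// Gated linear attention: per head and per token, with S a [head_size x head_size] state,
+//   S[i][j]    = g[i] * S[i][j] + k[i] * v[j]
+//   out[t][j] += scale * q[i] * S[i][j]   (summed over i)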
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
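+    // split the heads evenly across threads; this thread handles heads [h_start, h_end)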
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same as C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
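+    // thread 0 zeroes the output, then everyone waits before accumulating into it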
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
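+    // GLA_VECTOR_SIZE is only defined when one of the SIMD paths above was selected;
+    // otherwise the scalar fallback at the bottom is compiled instead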
+    #ifdef GLA_VECTOR_SIZE
+    const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+    for (int64_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
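+        // each sequence owns a head_size * C state slab; the previous state is read from
+        // src[4] on a sequence's first token and updated in place afterwards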
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float q_val = q[t_h_i_offset] * scale;
+                float g_val = g[t_h_i_offset];
+
+                // Broadcast scalar values to vectors
+                GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                for (int64_t j = 0; j < vec_count; j++) {
+                    size_t base_j = j * GLA_VECTOR_SIZE;
+                    size_t t_h_j_offset = t_h_offset + base_j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                    // Load GLA_VECTOR_SIZE elements at once
+                    GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                    GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                    GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                    // Compute kv = v * k
+                    GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                    // Compute temp = prev_state * g + kv
+                    GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                    // Update dst: dst += temp * q
+                    dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                    GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                    // Update state
+                    GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                }
+
+                // Handle remaining elements (head_size is expected to be a multiple of GLA_VECTOR_SIZE, so this tail loop normally does not run)
+                for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = kv_val + prev_state_val * g_val;
+                    dst_data[t_h_j_offset] += temp_val * q_val;
+                    state_cur[h_2d_i_j_offset] = temp_val;
+                }
+            }
+        }
+    }
+
+    #else
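+    // scalar fallback: the same recurrence without SIMD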
+    for (int64_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+        size_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t i = 0; i < head_size; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float k_val = k[t_h_i_offset];
+                float q_val = q[t_h_i_offset] * scale;
+                float g_val = g[t_h_i_offset];
+
+                for (int64_t j = 0; j < head_size; j++) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float v_val = v[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    float temp_val = prev_state_val * g_val + kv_val;
+                    dst_data[t_h_j_offset] += temp_val * q_val;
+                    state_cur[h_2d_i_j_offset] = temp_val;
+                }
+            }
+        }
+    }
+    #endif
+}
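+// layout note: dst packs the [C, T] output first, then the updated per-sequence states
+// (presumably constructed as ggml_gated_linear_attn(ctx, k, v, q, g, state, scale))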
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -12749,6 +12940,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -13047,6 +13242,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32: