@@ -813,14 +813,14 @@ def forward(
         sin = sin[None, :, None, :]
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids, self.fuse_rope)

-        query_states = paddle.concat([q_nope, q_pe], axis=-1)
-        key_states = paddle.concat([k_nope, k_pe], axis=-1)
+        query_states = paddle.cat([q_nope, q_pe], axis=-1)
+        key_states = paddle.cat([k_nope, k_pe], axis=-1)

         # [bs, seq_len, num_head, head_dim]
         if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = paddle.concat([past_key_value[0], key_states], axis=1)
-            value_states = paddle.concat([past_key_value[1], value_states], axis=1)
+            key_states = paddle.cat([past_key_value[0], key_states], axis=1)
+            value_states = paddle.cat([past_key_value[1], value_states], axis=1)
         past_key_value = (key_states, value_states) if use_cache else None

         has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient)
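
The cached branch in the hunk above appends the freshly projected keys and values to the running cache along the sequence axis (axis=1 in the [bs, seq_len, num_head, head_dim] layout). A minimal standalone sketch of that append pattern, with made-up shapes and written against the long-standing paddle.concat spelling:

import paddle

bs, num_head, head_dim = 2, 4, 64
cache_len, step_len = 16, 1

# running KV cache and the new decoding step, both [bs, seq_len, num_head, head_dim]
k_cache = paddle.randn([bs, cache_len, num_head, head_dim])
v_cache = paddle.randn([bs, cache_len, num_head, head_dim])
k_new = paddle.randn([bs, step_len, num_head, head_dim])
v_new = paddle.randn([bs, step_len, num_head, head_dim])

# append along the sequence axis, as in the past_key_value branch
k_cache = paddle.concat([k_cache, k_new], axis=1)
v_cache = paddle.concat([v_cache, v_new], axis=1)
assert k_cache.shape[1] == cache_len + step_len
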
@@ -1141,7 +1141,7 @@ def forward(
         hidden_states = self.hnorm(hidden_states)
         nextn_hidden_state = self.enorm(nextn_hidden_state)

-        concat_h = paddle.concat([nextn_hidden_state, hidden_states], axis=-1)
+        concat_h = paddle.cat([nextn_hidden_state, hidden_states], axis=-1)
         hidden_states = FP8LinearFunction.apply(concat_h, self.eh_proj)

         layer_outputs = super(DeepseekV2MTPLayer, self).forward(
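
The MTP layer above concatenates the normalized next-token embedding with the normalized hidden state and projects the doubled width back down through eh_proj (driven here by FP8LinearFunction). A rough shape-level sketch of that combine step, with plain LayerNorm and Linear modules standing in for the model's hnorm/enorm and the FP8 projection (all dimensions invented):

import paddle
import paddle.nn as nn

bs, seq_len, hidden = 2, 8, 128
hnorm = nn.LayerNorm(hidden)   # stand-in for self.hnorm
enorm = nn.LayerNorm(hidden)   # stand-in for self.enorm
eh_proj = nn.Linear(2 * hidden, hidden, bias_attr=False)  # stand-in for the FP8 eh_proj

hidden_states = paddle.randn([bs, seq_len, hidden])
nextn_hidden_state = paddle.randn([bs, seq_len, hidden])

concat_h = paddle.concat([enorm(nextn_hidden_state), hnorm(hidden_states)], axis=-1)
out = eh_proj(concat_h)        # back to [bs, seq_len, hidden]
assert out.shape == [bs, seq_len, hidden]
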
@@ -1686,7 +1686,7 @@ def forward(
             hidden_states = GatherOp.apply(hidden_states)
             hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]])

-            inputs_embeds_cur_depth = paddle.concat(
+            inputs_embeds_cur_depth = paddle.cat(
                 [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
             )

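The concatenation changed at line 1689 appears to build the input stream for MTP depth nextn + 1 by dropping the first nextn + 1 positions of the original embeddings and appending the corresponding look-ahead embeddings, i.e. the layer sees a sequence shifted left by that many tokens. A small indexing sketch with toy tensors (names and values invented) that makes the resulting layout visible:

import paddle

bs, seq_len, hidden, num_nextn = 1, 6, 4, 2
# original token embeddings, encoded so position i carries the value i
inputs_embeds_ori = paddle.arange(seq_len, dtype="float32").reshape([1, seq_len, 1]).tile([bs, 1, hidden])
# embeddings of the look-ahead tokens, marked with 100.0
inputs_embeds_extra = paddle.full([bs, num_nextn, hidden], 100.0)

nextn = 0  # first MTP depth
inputs_embeds_cur_depth = paddle.concat(
    [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
)
print(inputs_embeds_cur_depth[0, :, 0])  # positions 1..5 of the original stream, then one extra token
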
@@ -1848,7 +1848,7 @@ def _set_cos_sin_cache(self, seq_len):
             / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
         )

-        emb = paddle.concat((freqs, freqs), axis=-1)
+        emb = paddle.cat((freqs, freqs), axis=-1)
         self.cos_cached = emb.cos() * _mscale
         self.sin_cached = emb.sin() * _mscale

@@ -1919,7 +1919,7 @@ def _set_cos_sin_cache(self, seq_len):
         freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         # [seq_len, axis]
-        emb = paddle.concat([freqs, freqs], axis=-1)
+        emb = paddle.cat([freqs, freqs], axis=-1)
         # [1, seqlen, 1, axis]
         self.cos_cached = emb.cos()[None, :, None, :]
         self.sin_cached = emb.sin()[None, :, None, :]
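
Both _set_cos_sin_cache variants touched above follow the standard rotary-embedding recipe: take the outer product of positions and inverse frequencies, duplicate the half-width frequency table along the last axis, then cache its cosine and sine. A condensed, self-contained version of that recipe (without the YaRN mscale correction or the extra broadcast axes):

import paddle

dim, max_len, base = 64, 32, 10000.0

# inverse frequencies for the even channels, as in standard RoPE
inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2, dtype="float32") / dim))
t = paddle.arange(max_len, dtype="float32")

freqs = paddle.einsum("i,j->ij", t, inv_freq)  # [seq_len, dim // 2]
emb = paddle.concat([freqs, freqs], axis=-1)   # [seq_len, dim]
cos_cached, sin_cached = emb.cos(), emb.sin()
assert cos_cached.shape == [max_len, dim]
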
@@ -2137,8 +2137,8 @@ def qkv_pre_process_no_fuse(
     sin = sin[None, :, None, :]
     q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids, False)

-    query_states = paddle.concat([q_nope, q_pe], axis=-1)
-    key_states = paddle.concat([k_nope, k_pe], axis=-1)
+    query_states = paddle.cat([q_nope, q_pe], axis=-1)
+    key_states = paddle.cat([k_nope, k_pe], axis=-1)

     return query_states, key_states, value_states

@@ -2149,7 +2149,7 @@ def rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads):
     value_states = kv[..., qk_nope_head_dim:]

     k_pe = k_pe.expand([k_pe.shape[0], k_pe.shape[1], num_heads, k_pe.shape[3]])
-    key_states = paddle.concat([k_nope, k_pe], axis=-1)
+    key_states = paddle.cat([k_nope, k_pe], axis=-1)

     return key_states, value_states

@@ -2315,7 +2315,7 @@ def forward(
             [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim],
             dtype=value_states.dtype,
         )
-        value_states_pad = paddle.concat([value_states, value_padding], axis=-1)
+        value_states_pad = paddle.cat([value_states, value_padding], axis=-1)

         attn_out, _, softmax_lse, seed_offset = _C_ops.flash_attn(
             query_states,
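
In both the forward and backward paths, value_states (last dim v_head_dim) is zero-padded up to q_head_dim before being handed to the fused flash-attention kernel, presumably because the kernel expects query, key and value to share a head dimension; the extra channels carry no information. A shape-only sketch of the padding step (shapes invented, kernel call omitted):

import paddle

bsz, kv_seq_len, v_num_heads = 2, 16, 4
q_head_dim, v_head_dim = 192, 128

value_states = paddle.randn([bsz, kv_seq_len, v_num_heads, v_head_dim])
value_padding = paddle.zeros(
    [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim],
    dtype=value_states.dtype,
)
value_states_pad = paddle.concat([value_states, value_padding], axis=-1)
assert value_states_pad.shape[-1] == q_head_dim

# slicing the pad back off recovers the original v_head_dim channels
value_states_unpadded = value_states_pad[..., :v_head_dim]
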
@@ -2541,7 +2541,7 @@ def backward(ctx, dout):
             [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim],
             dtype=value_states.dtype,
         )
-        value_states_pad = paddle.concat([value_states, value_padding], axis=-1)
+        value_states_pad = paddle.cat([value_states, value_padding], axis=-1)

         with paddle.no_grad():
@@ -2655,7 +2655,7 @@ def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_sca
         compressed_kv, kv_ln_weight, kv_ln_invar, d_kv_ln_t, eps
     )

-    d_kv_init = paddle.concat([d_compressed_kv, d_k_pe], axis=-1)
+    d_kv_init = paddle.cat([d_compressed_kv, d_k_pe], axis=-1)

     if hasattr(q_up_weight, "main_grad"):
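
Every change in this commit is the same one-token substitution: paddle.concat becomes paddle.cat with identical arguments. Assuming a Paddle build that ships paddle.cat as a PyTorch-style alias of paddle.concat (that is the premise of the commit; older releases do not have it), the two spellings can be sanity-checked against each other directly:

import paddle

x = paddle.randn([2, 3])
y = paddle.randn([2, 3])

out_concat = paddle.concat([x, y], axis=-1)

# guard the alias so the sketch still runs on Paddle versions without paddle.cat
if hasattr(paddle, "cat"):
    out_cat = paddle.cat([x, y], axis=-1)
    assert bool(paddle.all(out_concat == out_cat))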