@@ -813,24 +813,28 @@ def _init_rope(self):
             self.rotary_emb = LlamaRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "linear":
             self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "ntk":
             self.rotary_emb = LlamaNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "dynamic_ntk":
             self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         else:
             raise ValueError(f"Unknown RoPE scaling type {self.config.rope_scaling_type}")
@@ -903,6 +907,7 @@ def forward(
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
             value_states = self.v_proj(hidden_states)
+
             if self.reshard_layer is not None:
                 if self.sequence_parallel:
                     assert self.seq_length % self.config.sep_parallel_degree == 0
@@ -1027,7 +1032,6 @@ def forward(
             value_states = paddle.concat([past_key_value[1], value_states], axis=1)

         past_key_value = (key_states, value_states) if use_cache else None
-
         if self.kv_indices is not None:
             key_states = paddle.index_select(key_states, self.kv_indices, axis=2)
             value_states = paddle.index_select(value_states, self.kv_indices, axis=2)
@@ -1036,7 +1040,7 @@ def forward(
         # repeat k/v heads if n_kv_heads < n_heads
         # paddle version > 2.6 or develop support flash-attn with gqa/mqa
         paddle_version = float(paddle.__version__[:3])
-        if (paddle_version != 0.0) and (paddle_version <= 2.6):
+        if not self.config.use_flash_attention or ((paddle_version != 0.0) and (paddle_version <= 2.6)):
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

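[Note, not part of the diff] The original guard only repeated K/V heads on Paddle <= 2.6, where the flash-attention kernel cannot consume fewer K/V heads than query heads; the new guard also repeats them whenever flash attention is disabled, since the fallback attention path likewise expects matching head counts. A rough sketch of what a repeat_kv-style helper does, assuming a [batch, seq_len, num_key_value_heads, head_dim] layout (an illustration, not the helper defined elsewhere in this file):

import paddle


def repeat_kv_sketch(hidden_states, n_rep):
    # Tile each KV head n_rep times so grouped-query K/V line up with the
    # number of query heads before a plain (non-flash) attention matmul.
    if n_rep == 1:
        return hidden_states
    batch, seq_len, num_kv_heads, head_dim = hidden_states.shape
    hidden_states = hidden_states.unsqueeze(-2)                      # [b, s, kv, 1, d]
    hidden_states = paddle.tile(hidden_states, [1, 1, 1, n_rep, 1])  # [b, s, kv, n_rep, d]
    return hidden_states.reshape([batch, seq_len, num_kv_heads * n_rep, head_dim])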
@@ -1560,7 +1564,6 @@ def forward(
         else:
             attention_mask = attention_mask.astype("bool")
         hidden_states = inputs_embeds
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None