 from __future__ import annotations

 import math
+import os
 import warnings
 from functools import partial
 from typing import Optional, Tuple
@@ -82,6 +83,17 @@ def swiglu(x, y=None):
 ]


+def enable_fuse_ffn_qkv_pass():
+    if os.getenv("FLAGS_enable_fused_ffn_qkv_pass") in [
+        "True",
+        "true",
+        "1",
+    ]:
+        return True
+    else:
+        return False
+
+
 def is_pp_enable():
     mesh = fleet.auto.get_mesh()
     return "pp" in mesh.dim_names
@@ -221,7 +233,7 @@ def __init__(self, config, ipp: Optional[int] = None):
         self.ipp = ipp
         self.config = config

-        if config.fuse_attention_ffn:
+        if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
             self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
             self.gate_up_fused_proj.weight = dist.shard_tensor(
                 self.gate_up_fused_proj.weight,
@@ -251,7 +263,7 @@ def __init__(self, config, ipp: Optional[int] = None):
         )

     def forward(self, x):
-        if self.fuse_attention_ffn:
+        if self.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
@@ -298,7 +310,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
                 )
                 self.use_fused_rope = False

-        if self.fuse_attention_qkv:
+        if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
             self.qkv_proj = nn.Linear(
                 self.hidden_size,
                 self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
@@ -412,7 +424,7 @@ def forward(
                 [dist.Shard(1), dist.Replicate()],
             )

-        if self.fuse_attention_qkv:
+        if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
             target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
             mix_layer = self.qkv_proj(hidden_states)
             mix_layer = paddle.reshape_(mix_layer, target_shape)
@@ -760,7 +772,7 @@ def get_tensor_parallel_split_mappings(num_layers):
             }

             # Column Linear
-            if config.fuse_attention_qkv:
+            if config.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
                 base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
             else:
                 base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
@@ -769,7 +781,7 @@ def get_tensor_parallel_split_mappings(num_layers):
                 base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
                 base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)

-            if config.fuse_attention_ffn:
+            if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
                 base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
                     fn, is_column=True, is_naive_2fuse=True
                 )
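In short, every eager fused branch (fuse_attention_qkv / fuse_attention_ffn) is now skipped when the FLAGS_enable_fused_ffn_qkv_pass environment variable is set to a truthy string, presumably so the fusion can be applied by a pass instead of the hand-fused layers. A minimal sketch of how the helper added in this diff behaves, using only the standard library; the demo flag values below are illustrative, not taken from this PR:

    import os


    def enable_fuse_ffn_qkv_pass():
        # Mirrors the helper added in this diff: the pass is treated as "on"
        # only for the exact strings "True", "true", or "1".
        return os.getenv("FLAGS_enable_fused_ffn_qkv_pass") in ["True", "true", "1"]


    if __name__ == "__main__":
        # Illustrative check: with the flag set, the fused-weight branches are bypassed.
        os.environ["FLAGS_enable_fused_ffn_qkv_pass"] = "1"
        print(enable_fuse_ffn_qkv_pass())  # True

        os.environ["FLAGS_enable_fused_ffn_qkv_pass"] = "0"
        print(enable_fuse_ffn_qkv_pass())  # False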