Commit 1ce2642

[AutoParallel] Add FLAGS_enable_fused_ffn_qkv_pass for llama (#9182)
1 parent c623901 commit 1ce2642

File tree

1 file changed: 18 additions, 6 deletions

paddlenlp/transformers/llama/modeling_auto.py

Lines changed: 18 additions & 6 deletions
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import math
+import os
 import warnings
 from functools import partial
 from typing import Optional, Tuple
@@ -82,6 +83,17 @@ def swiglu(x, y=None):
 ]
 
 
+def enable_fuse_ffn_qkv_pass():
+    if os.getenv("FLAGS_enable_fused_ffn_qkv_pass") in [
+        "True",
+        "true",
+        "1",
+    ]:
+        return True
+    else:
+        return False
+
+
 def is_pp_enable():
     mesh = fleet.auto.get_mesh()
     return "pp" in mesh.dim_names
@@ -221,7 +233,7 @@ def __init__(self, config, ipp: Optional[int] = None):
         self.ipp = ipp
         self.config = config
 
-        if config.fuse_attention_ffn:
+        if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
             self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
             self.gate_up_fused_proj.weight = dist.shard_tensor(
                 self.gate_up_fused_proj.weight,
@@ -251,7 +263,7 @@ def __init__(self, config, ipp: Optional[int] = None):
             )
 
     def forward(self, x):
-        if self.fuse_attention_ffn:
+        if self.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
             x = swiglu(self.gate_up_fused_proj(x))
         else:
             x = swiglu(self.gate_proj(x), self.up_proj(x))
@@ -298,7 +310,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
                 )
                 self.use_fused_rope = False
 
-        if self.fuse_attention_qkv:
+        if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
             self.qkv_proj = nn.Linear(
                 self.hidden_size,
                 self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
@@ -412,7 +424,7 @@ def forward(
                 [dist.Shard(1), dist.Replicate()],
             )
 
-        if self.fuse_attention_qkv:
+        if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
             target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
             mix_layer = self.qkv_proj(hidden_states)
             mix_layer = paddle.reshape_(mix_layer, target_shape)
@@ -760,7 +772,7 @@ def get_tensor_parallel_split_mappings(num_layers):
             }
 
             # Column Linear
-            if config.fuse_attention_qkv:
+            if config.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
                 base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
             else:
                 base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
@@ -769,7 +781,7 @@ def get_tensor_parallel_split_mappings(num_layers):
                     base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
                     base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
 
-            if config.fuse_attention_ffn:
+            if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass():
                 base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial(
                     fn, is_column=True, is_naive_2fuse=True
                 )
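
The new helper only inspects the FLAGS_enable_fused_ffn_qkv_pass environment variable, so the behavior can be toggled without changing the model config. Below is a minimal usage sketch, assuming a PaddleNLP installation that includes this change; only the environment variable name and the helper come from this commit, the assertion is illustrative:

    # Minimal sketch: set the flag before constructing the auto-parallel llama model,
    # so the branches above fall back to the separate q/k/v and gate/up projections.
    import os

    os.environ["FLAGS_enable_fused_ffn_qkv_pass"] = "1"  # "True", "true", or "1" are accepted

    from paddlenlp.transformers.llama.modeling_auto import enable_fuse_ffn_qkv_pass

    assert enable_fuse_ffn_qkv_pass()  # reads the environment variable at call time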
