
Commit 884ca21

Balance_gate & O1 recompute configuration (#10883)
* balance_gate
* add configuration of o1 rc nums
1 parent 6d58c7c commit 884ca21

6 files changed: 98 additions and 19 deletions

paddlenlp/transformers/deepseek_v2/configuration.py

Lines changed: 3 additions & 1 deletion
@@ -182,8 +182,9 @@ def __init__(
         use_dualpipev=False,
         send_mtp_embed=False,
         using_post_norm_recompute=False,
-        recompute_fwd_gate_up=False,
+        recompute_fwd_gate_up=0,
         is_split_group_gemm=False,
+        fakse_gate_restrict_balance=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -237,6 +238,7 @@ def __init__(
         self.using_post_norm_recompute = using_post_norm_recompute
         self.recompute_fwd_gate_up = recompute_fwd_gate_up
         self.is_split_group_gemm = is_split_group_gemm
+        self.fakse_gate_restrict_balance = fakse_gate_restrict_balance

         super().__init__(
             pad_token_id=pad_token_id,
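The two new configuration knobs can be exercised together. A minimal sketch, assuming the config class defined in the file above is DeepseekV2Config and that `using_fake_gate` (which the modeling code only probes with `hasattr`) is simply forwarded onto the config through `**kwargs`:

from paddlenlp.transformers.deepseek_v2.configuration import DeepseekV2Config

config = DeepseekV2Config(
    recompute_fwd_gate_up=2,           # now an int: recompute the gate/up (O1) output for 2 MoE layers per pipeline segment
    fakse_gate_restrict_balance=True,  # FakeGate emits a balanced round-robin assignment instead of random logits
    using_fake_gate=True,              # assumed to reach the config via **kwargs
)
print(config.recompute_fwd_gate_up, config.fakse_gate_restrict_balance)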

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 31 additions & 4 deletions
@@ -772,18 +772,33 @@ def backward(ctx, d_gate_logits, d_norm_output):
         return dx, d_rms_norm_weight, d_moe_gate_weight


+def balance_expert_assignment(n, m, k):
+    assert k * n % m == 0
+    matrix = paddle.zeros((n, m), dtype=paddle.int32)
+    for row in range(n):
+        start_col = row % m
+        for i in range(k):
+            col = (start_col + i) % m
+            matrix[row, col] = 1
+    return matrix
+
+
 class FakeGate(paddle.autograd.PyLayer):
     @staticmethod
-    def forward(ctx, hidden_states, weight):
+    def forward(ctx, hidden_states, weight, fakse_gate_restrict_balance=False, num_experts_per_tok=8):
         expert_num = weight.shape[1]
         bsz, seq, _ = hidden_states.shape

         ctx.x_shape = hidden_states.shape
         ctx.x_dtype = hidden_states.dtype
         ctx.y_shape = weight.shape
         ctx.y_dtype = weight.dtype
-
-        return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype)
+        if fakse_gate_restrict_balance:
+            return paddle.reshape(
+                balance_expert_assignment(bsz * seq, expert_num, num_experts_per_tok), [bsz, seq, expert_num]
+            )
+        else:
+            return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype)

     @staticmethod
     def backward(ctx, grad_output):
@@ -841,11 +856,23 @@ def forward(self, hidden_states):
         # compute gating score
         if self.using_post_norm_recompute:
             logits, norm_out = FusedNormGateFunc.apply(hidden_states, self.norm_weight, self.weight, self.norm_eps)
+            if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate:
+                logits = FakeGate.apply(
+                    hidden_states,
+                    self.weight,
+                    self.config.fakse_gate_restrict_balance,
+                    self.config.num_experts_per_tok,
+                )
         else:
             with paddle.amp.auto_cast(False):
                 hidden_states = hidden_states.cast(self.weight.dtype)
                 if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate:
-                    logits = FakeGate.apply(hidden_states, self.weight)
+                    logits = FakeGate.apply(
+                        hidden_states,
+                        self.weight,
+                        self.config.fakse_gate_restrict_balance,
+                        self.config.num_experts_per_tok,
+                    )
                 else:
                     logits = F.linear(hidden_states, self.weight, None)
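The round-robin pattern produced by balance_expert_assignment is easy to reproduce without Paddle. A standalone sketch mirroring the logic above (the sizes, 8 tokens routed top-2 across 4 experts, are illustrative only):

def balance_expert_assignment_py(n, m, k):
    # n tokens, m experts, k experts selected per token, assigned round-robin;
    # when n is a multiple of m every expert receives exactly k * n / m tokens
    assert k * n % m == 0
    matrix = [[0] * m for _ in range(n)]
    for row in range(n):
        start_col = row % m
        for i in range(k):
            matrix[row][(start_col + i) % m] = 1
    return matrix

per_expert = [sum(col) for col in zip(*balance_expert_assignment_py(8, 4, 2))]
assert per_expert == [4, 4, 4, 4]  # a perfectly balanced token-to-expert load

This matrix is what FakeGate.forward returns (reshaped to [bsz, seq, expert_num]) when fakse_gate_restrict_balance is set, replacing the random logits used before.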

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 41 additions & 2 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import math
 import os
 from typing import OrderedDict, Tuple, Union

@@ -29,6 +30,7 @@
 from paddle.distributed.fleet.recompute.recompute import recompute
 from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp

+from ...utils.log import logger
 from ...utils.tools import get_env_device
 from ..model_utils import PipelinePretrainedModel
 from .modeling import (
@@ -1445,6 +1447,43 @@ def get_hcg():
             LayerDesc(DeepseekV2EmbeddingPipe, config=config), self._base_model.base_model_prefix
         )

+        def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, recompute_fwd_gate_up):
+            all_layers_nums = all_dl_nums + 4  # embedding, rms, lm_head, mtp
+            segment_size = all_layers_nums // pp_nums
+            boundary = math.ceil((1 + dense_dl_nums) / segment_size) * segment_size
+            recompute_fwd_gate_up_list = [dense_dl_nums]
+            for idx in range(boundary - 1, all_dl_nums, segment_size):
+                recompute_fwd_gate_up_list.append(idx)
+
+            # If `recompute_fwd_gate_up` is a boolean and is True, all O1 will be recomputed.
+            # Otherwise `recompute_fwd_gate_up` should be an integer giving how many O1 are recomputed.
+            assert isinstance(recompute_fwd_gate_up, (int, bool))
+            if type(recompute_fwd_gate_up) is bool:
+                enable_k_o1_rc = segment_size if recompute_fwd_gate_up is True else 0
+            else:
+                enable_k_o1_rc = recompute_fwd_gate_up
+
+            ret = []
+            for i in range(len(recompute_fwd_gate_up_list)):
+                for k in range(min(segment_size, enable_k_o1_rc)):
+                    ret.append(recompute_fwd_gate_up_list[i] + k)
+            return ret
+
+        pp_nums = (
+            self.config["pipeline_parallel_degree"] * 2
+            if self.config.use_dualpipev
+            else self.config["pipeline_parallel_degree"]
+        )
+        recompute_fwd_gate_up_list = compute_recompute_fwd_gate_up_list(
+            pp_nums,
+            self.config.num_hidden_layers,
+            self.config.first_k_dense_replace,
+            self.config.recompute_fwd_gate_up,
+        )
+
+        logger.info(f"recompute_fwd_gate_up_list: {recompute_fwd_gate_up_list}")
+        config.recompute_fwd_gate_up_list = recompute_fwd_gate_up_list
+
         for i in range(config.num_hidden_layers):
             self.add_sequential_layer(
                 LayerDesc(
@@ -1519,8 +1558,8 @@ def overlapped_forward_backward(
         backward_loss_fn_node,
         backward_input_grads,
         scaler,
-        combine_bw_event_to_wait = None,
-        pp_stream=None
+        combine_bw_event_to_wait=None,
+        pp_stream=None,
     ):
         if backward_loss_fn_node is not None:
             if scaler:
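A worked example of what compute_recompute_fwd_gate_up_list returns may help. The numbers below are illustrative only (8 virtual pipeline stages, 60 decoder layers of which the first 3 are dense, recompute_fwd_gate_up=2), not taken from any shipped config:

import math

pp_nums, all_dl_nums, dense_dl_nums, recompute_fwd_gate_up = 8, 60, 3, 2

all_layers_nums = all_dl_nums + 4          # 64 (embedding, rms, lm_head, mtp added)
segment_size = all_layers_nums // pp_nums  # 8 layers per virtual pipeline stage
boundary = math.ceil((1 + dense_dl_nums) / segment_size) * segment_size  # 8
bases = [dense_dl_nums] + list(range(boundary - 1, all_dl_nums, segment_size))
# bases == [3, 7, 15, 23, 31, 39, 47, 55]
o1_layers = [b + k for b in bases for k in range(min(segment_size, recompute_fwd_gate_up))]
# o1_layers == [3, 4, 7, 8, 15, 16, 23, 24, 31, 32, 39, 40, 47, 48, 55, 56]

With recompute_fwd_gate_up=2, each base contributes its first two layers to the recompute list; passing True instead contributes all segment_size layers per base, and False (or 0) yields an empty list.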

paddlenlp/transformers/fp8_utils.py

Lines changed: 5 additions & 0 deletions
@@ -1117,6 +1117,11 @@ def backward_dx(self, out_grad):

         self.out_grad = out_grad

+        # clear state to save memory
+        self.m_indices = None
+        self.unzipped_probs = None
+        self.input = None
+
         # dx
         dx = self.bwd_gate_up_input(do1, expert_w1, dx=out_grad[0] if isinstance(out_grad, tuple) else out_grad)
paddlenlp/transformers/moe_layer.py

Lines changed: 16 additions & 8 deletions
@@ -668,7 +668,7 @@ def backward(self, output_grad, event_to_wait=None):
         if DSV3_USE_FP8_DISPATCH:
             if event_to_wait is not None:
                 assert self.moe_group is not None
-                event_to_wait.comm_stream_wait( self.moe_group.id)
+                event_to_wait.comm_stream_wait(self.moe_group.id)
                 buffer = get_buffer(self.token_dispatcher._comm_manager.group, get_hidden_bytes(output_grad))
                 custom_stream = paddle.device.Stream(stream_base=buffer.runtime.get_comm_stream())
             else:
@@ -697,19 +697,19 @@ class FusionMlpNode:
     def __init__(self, custom_map, max_topk, recompute_fwd_gate_up=False, is_split_group_gemm=True):
         self.token_dispatcher = custom_map.token_dispatcher
         self.experts = custom_map.experts
+        self.unzip_node = UnZipNode()
+        self.zip_node = ZipNode()
         self.experts_group_gemm_node = FP8GroupGemmMlpFunctionNode(
             custom_map,
             recompute_fwd_gate_up=recompute_fwd_gate_up,
             is_split_group_gemm=is_split_group_gemm,
         )
-        self.unzip_node = UnZipNode(self.token_dispatcher)
-        self.zip_node = ZipNode(self.token_dispatcher)
         self.dispatched_indices = None
         self.dispatched_probs = None
         self.tokens_per_expert = None
         self.router_topk = max_topk

-    def reset_statue(self):
+    def reset_statue(self, with_dw=False):
         """
         Reset all state variables.

@@ -724,8 +724,15 @@ def reset_statue(self):
         self.dispatched_probs = None
         self.tokens_per_expert = None
         self.router_topk = None
-        self.experts_group_gemm_node.reset_statue()
-        self.experts_group_gemm_node = None
+
+        del self.unzip_node
+        del self.zip_node
+        self.unzip_node = None
+        self.zip_node = None
+
+        if with_dw:
+            self.experts_group_gemm_node.reset_statue()
+            self.experts_group_gemm_node = None

     @paddle.no_grad()
     def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs):
@@ -847,13 +854,14 @@ def backward(self, hidden_states_out_grad, with_dw=True):
             self.dispatched_indices,
             num_experts=len(self.tokens_per_expert),
         )
-        if with_dw:
-            self.reset_statue()
+
+        self.reset_statue(with_dw)
         return hs_dispatched_grad, dispatched_probs_grad

     @paddle.no_grad()
     def backward_dw(self):
         self.experts_group_gemm_node.backward_dw()
+        self.reset_statue(True)


 class FusionMoeNode:

paddlenlp/transformers/moe_utils.py

Lines changed: 2 additions & 4 deletions
@@ -125,8 +125,7 @@ def unpermute(


 class UnZipNode:
-    def __init__(self, token_dispatcher, name="unzip"):
-        self.token_dispatcher = token_dispatcher
+    def __init__(self, name="unzip"):
         self.name = name
         self.unzipped_probs = None
         self.zipped_expertwise_rowmap = None
@@ -199,8 +198,7 @@ def backward(self, dx, hidden_states_out_grad, probs_grad, dispatched_indices, n


 class ZipNode:
-    def __init__(self, token_dispatcher, name="zip"):
-        self.token_dispatcher = token_dispatcher
+    def __init__(self, name="zip"):
         self.name = name

     @paddle.no_grad()
