Commit 3737315

best speed (#11023)
1 parent 5b4855d commit 3737315

7 files changed: +77 -55 lines changed

llm/config/deepseek-v3/pretrain_argument.json

Lines changed: 8 additions & 7 deletions
@@ -3,16 +3,17 @@
     "tokenizer_name_or_path": "deepseek-ai/DeepSeek-V3",
     "input_dir": "./data",
     "output_dir": "./checkpoints/pretrain_ckpts",
+    "resume_from_huggingface_ckpt": "./huggingface_ckpt/",
     "per_device_train_batch_size": 1,
     "gradient_accumulation_steps": 60,
     "per_device_eval_batch_size": 1,
     "tensor_parallel_degree": 1,
     "pipeline_parallel_degree": 8,
-    "pipeline_parallel_config": "use_dualpipev",
-    "sharding_parallel_degree": 64,
-    "sharding_parallel_config": "split_param enable_fuse_optimizer_states",
-    "sharding_comm_buffer_size_MB": 4096,
-    "expert_parallel_degree": 64,
+    "pipeline_parallel_config": "use_dualpipev enable_overlap_p2p_comm",
+    "sharding_parallel_degree": 32,
+    "sharding_parallel_config": "split_param",
+    "sharding_comm_buffer_size_MB": 2048,
+    "expert_parallel_degree": 32,
     "sharding": "stage1",
     "virtual_pp_degree": 1,
     "sequence_parallel": 0,
@@ -47,7 +48,7 @@
     "use_fused_rope": true,
     "save_sharded_model": false,
     "load_sharded_model": false,
-    "unified_checkpoint": true,
     "use_expert_parallel": true,
-    "unified_checkpoint_config": "skip_save_model_weight"
+    "unified_checkpoint_config": "skip_save_model_weight",
+    "offload_optim": true
 }
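
The new settings halve the sharding and expert-parallel widths (64 to 32), keep the 8-stage dualpipev pipeline while enabling overlapped p2p communication, and swap the fused-optimizer-state sharding option for optimizer offload. A quick sanity check of the implied GPU count (a minimal sketch; the keys are the JSON fields above, and the assumption that world size = tp * pp * sharding under stage1 is mine, not something the commit states):

    import json

    # Recompute the implied world size from the config above (illustrative only).
    with open("llm/config/deepseek-v3/pretrain_argument.json") as f:
        cfg = json.load(f)

    world_size = (
        cfg["tensor_parallel_degree"]      # 1
        * cfg["pipeline_parallel_degree"]  # 8
        * cfg["sharding_parallel_degree"]  # 32
    )
    print(world_size)  # 256 under these assumptions (the old degrees implied 512)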

llm/model_config/DeepSeek-V3/config.json

Lines changed: 8 additions & 3 deletions
@@ -59,15 +59,20 @@
     "v_head_dim": 128,
     "vocab_size": 129280,
     "using_flex_token": true,
-    "using_fake_gate": true,
+    "using_fake_gate": false,
     "use_fused_rms_norm": true,
     "fuse_attention_ffn": true,
     "use_fused_rope": true,
     "token_drop_steps": 0,
     "recompute_fwd_gate_up": true,
-    "adaptive_remained_O1_recompute_ratio": 2.0,
+    "adaptive_remained_O1_recompute_ratio": 0,
     "using_post_norm_recompute": true,
     "is_split_group_gemm": false,
     "use_dualpipev": true,
-    "send_mtp_embed": false
+    "send_mtp_embed": false,
+    "mlp_fwd_subbatch_rows": 0,
+    "mlp_bwd_subbatch_rows": 65536,
+    "output_subbatch_rows": 2048,
+    "recompute_fa3": true,
+    "stepped_recompute_fwd_gate_up": true
 }
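
The new *_subbatch_rows keys cap activation memory by running the big MLP and output-projection GEMMs over row chunks; 0 disables chunking, so here only the MLP backward (65536 rows) and the output projection (2048 rows) are chunked. A minimal illustration of the row-subbatching idea, not the actual PaddleNLP kernels, and the helper name is invented:

    import paddle
    import paddle.nn.functional as F

    def mlp_in_row_subbatches(x, w1, w2, subbatch_rows):
        """Run a 2-layer MLP over row chunks so the large intermediate
        activation exists for only one chunk at a time (illustrative only)."""
        if subbatch_rows <= 0 or x.shape[0] <= subbatch_rows:  # 0 means "no subbatching"
            return paddle.matmul(F.silu(paddle.matmul(x, w1)), w2)
        outs = []
        for start in range(0, x.shape[0], subbatch_rows):
            chunk = x[start : start + subbatch_rows]
            outs.append(paddle.matmul(F.silu(paddle.matmul(chunk, w1)), w2))
        return paddle.concat(outs, axis=0)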

llm/script/train_gpu.sh

Lines changed: 10 additions & 1 deletion
@@ -47,13 +47,22 @@ fi
 export PYTHONPATH=../:$PYTHONPATH
 export CUDA_PATH=/usr/local/cuda-12.9
 
+# Flags for best performance
 export DSV3_USE_FP8_GEMM=true
 export DSV3_USE_ATTEN_RECOMPUTE=true
 export FA_VERSION=3
 export FLAGS_share_tensor_for_grad_tensor_holder=1
 export FLAGS_use_default_stream=false
 export DSV3_USE_FP8_DISPATCH=true
-export USE_DS_GEMM=false
+export USE_DS_GEMM=true
+
+# Flags for allocator
+export FLAGS_large_pool_auto_growth_chunk_size_in_mb=500
+export FLAGS_small_pool_auto_growth_chunk_size_in_mb=20
+export FLAGS_small_pool_size_in_mb=10
+export FLAGS_samll_pool_pre_alloc_in_mb=500
+export FLAGS_large_pool_pre_alloc_in_mb=61440
+export FLAGS_deep_ep_comm_prealloc_in_mb=1000
 
 
 bash script/kill_process.sh
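
These variables are read when Paddle initializes, so they have to be in the environment before the training process starts. If a job is launched from Python rather than through this script, the same flags can presumably be set via os.environ before importing paddle (a sketch under that assumption; the values are copied from the hunk above):

    import os

    # Mirror a few of the flags from llm/script/train_gpu.sh; set them before
    # `import paddle` so the allocator and FP8 paths see them at startup.
    os.environ.update({
        "DSV3_USE_FP8_GEMM": "true",
        "FA_VERSION": "3",
        "FLAGS_large_pool_pre_alloc_in_mb": "61440",  # pre-allocate ~60 GiB per GPU
        "FLAGS_deep_ep_comm_prealloc_in_mb": "1000",
    })

    import paddle  # noqa: E402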

paddlenlp/trainer/utils/load_hf_ckpt.py

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ def _get_hf_prefix(segment_id: int, id_in_segment: int) -> str:
     # special_cases = {(0, 0): "model", (28, 2): "model.layers.61", (28, 3): "model"}
     # special_cases = {(0, 0): "model", (28, 2): "model.layers.61", (4, 1): "model"}
     # special_cases = {(0, 0): "model", (28, 2): "model", (28,3): "lm_head"}
-    special_cases = {(0, 0): "model", (60, 2): "model", (60, 3): "lm_head"}
+    special_cases = {(0, 0): "model", (60, 2): "model.layers.61", (60, 3): "model", (60, 4): "lm_head"}
 
     if (segment_id, id_in_segment) in special_cases:
         return special_cases[(segment_id, id_in_segment)]
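
The special-case table for the last pipeline segment (60) now maps one extra sublayer: index 2 resolves to the MTP block (model.layers.61), index 3 to the base model, and index 4 to the LM head. A small illustrative lookup, assuming _get_hf_prefix is importable from this module:

    from paddlenlp.trainer.utils.load_hf_ckpt import _get_hf_prefix

    print(_get_hf_prefix(60, 2))  # "model.layers.61", the MTP weights in the HF checkpoint
    print(_get_hf_prefix(60, 4))  # "lm_head"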

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 25 additions & 24 deletions
@@ -85,6 +85,7 @@
 from ..fp8_utils import (
     FP8KeepXLinear,
     FP8Linear,
+    FP8LinearFunction,
     FP8LinearFunctionBase,
     FP8Mlp,
     cache_fp8_weight,
@@ -1552,7 +1553,10 @@ def backward(ctx, dout):
         else:
             assert False, f"invalid {FA_VERSION=}"
 
-        if (FA_VERSION == 3 and not recompute_fa3) or FA_VERSION == 2:
+        if FA_VERSION == 2:
+            assert not recompute_fa3
+            assert attn_out is not None and softmax_lse is not None
+        if FA_VERSION == 3 and not recompute_fa3:
             assert attn_out is not None and softmax_lse is not None
 
         q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)
@@ -1636,25 +1640,25 @@ def backward(ctx, dout):
         elif FA_VERSION == 3:
             # recompute fa3
             if recompute_fa3:
-                logger.info("Enable fa3 recomputation")
-                attn_out, softmax_lse = _C_ops.flash_attn_v3(
-                    query_states,
-                    key_states,
-                    value_states,
-                    None,  # q_v_
-                    None,  # q_descale_
-                    None,  # k_descale_
-                    None,  # v_descale_
-                    softmax_scale,
-                    True,
-                    -1,  # window_size_left
-                    -1,  # window_size_right
-                    0.0,  # softcap
-                    1,  # num_splits
-                    False,  # manual_set_pack_gqa
-                    False,  # pack_gqa_
-                    0,  # sm_margin
-                )
+                with paddle.no_grad():
+                    attn_out, softmax_lse = _C_ops.flash_attn_v3(
+                        query_states,
+                        key_states,
+                        value_states,
+                        None,  # q_v_
+                        None,  # q_descale_
+                        None,  # k_descale_
+                        None,  # v_descale_
+                        softmax_scale,
+                        True,
+                        -1,  # window_size_left
+                        -1,  # window_size_right
+                        0.0,  # softcap
+                        1,  # num_splits
+                        False,  # manual_set_pack_gqa
+                        False,  # pack_gqa_
+                        0,  # sm_margin
+                    )
             with paddle.no_grad():
                 q_grad, k_grad, v_grad = _C_ops.flash_attn_v3_grad(
                     query_states,
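
With recompute_fa3 enabled, the PyLayer's forward no longer keeps attn_out and softmax_lse, so the backward has to re-run flash_attn_v3 before calling flash_attn_v3_grad; the change wraps that re-run in paddle.no_grad() so the recomputation is not itself traced, and drops the per-step logger.info call. The pattern in miniature (a sketch, not the attention code itself):

    import paddle

    class RecomputedExp(paddle.autograd.PyLayer):
        """Recompute-in-backward: save only the cheap input, rebuild the
        activation under no_grad when the gradient needs it."""

        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)   # deliberately do not keep y = exp(x)
            return paddle.exp(x)

        @staticmethod
        def backward(ctx, dout):
            (x,) = ctx.saved_tensor()
            with paddle.no_grad():     # recomputation must not build a new graph
                y = paddle.exp(x)      # plays the role of the flash_attn_v3 re-run above
            return dout * y
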
@@ -2587,7 +2591,7 @@ def forward(
         nextn_hidden_state = self.enorm(nextn_hidden_state)
 
         concat_h = paddle.concat([nextn_hidden_state, hidden_states], axis=-1)
-        hidden_states = LMHeadFunction.apply(concat_h, self.eh_proj.weight, False)
+        hidden_states = FP8LinearFunction.apply(concat_h, self.eh_proj)
 
         layer_outputs = super(DeepseekV2MTPLayer, self).forward(
             hidden_states,
@@ -3180,11 +3184,8 @@ def forward(
 class FastCrossEntropyFunction(paddle.autograd.PyLayer):
     @staticmethod
     def forward(ctx, preds, labels):
-
         softmax_val, loss = paddle._C_ops.cross_entropy_with_softmax(preds, labels, False, True, False, -100, -1)
 
-        # print("softmax val", softmax_val.dtype)
-
         ctx.save_for_backward(labels, softmax_val)
         return loss
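
The FastCrossEntropyFunction change above only removes a stray blank line and a commented-out debug print; the forward still saves (labels, softmax_val) for the backward. For orientation, the matching gradient is the usual softmax-minus-one-hot rule, shown here as a plain-op sketch (not the code in this file, and without the fused kernel or ignore_index handling):

    import paddle
    import paddle.nn.functional as F

    class SketchCrossEntropy(paddle.autograd.PyLayer):
        """Same save/restore contract as FastCrossEntropyFunction, in plain ops."""

        @staticmethod
        def forward(ctx, preds, labels):
            softmax_val = F.softmax(preds, axis=-1)                               # [N, C]
            picked = paddle.take_along_axis(softmax_val, labels.unsqueeze(-1), axis=-1)
            loss = -paddle.log(picked).squeeze(-1)                                # [N]
            ctx.save_for_backward(labels, softmax_val)
            return loss

        @staticmethod
        def backward(ctx, dloss):
            labels, softmax_val = ctx.saved_tensor()
            # d(loss_i)/d(preds_i) = softmax_i - one_hot(label_i); labels get no gradient.
            grad = softmax_val - F.one_hot(labels, softmax_val.shape[-1]).astype(softmax_val.dtype)
            return grad * dloss.unsqueeze(-1)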

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 10 additions & 10 deletions
@@ -67,7 +67,7 @@
 )
 from paddlenlp.transformers.moe_layer import FusionMoeNode
 
-from ..fp8_utils import FP8LinearFunctionBase
+from ..fp8_utils import FP8LinearFunction, FP8LinearFunctionBase
 
 __all__ = [
     "DeepseekV2ForCausalLMPipe",
@@ -1204,11 +1204,12 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
             combine_forward_event.current_stream_wait()
             final_out_event.current_stream_wait()
 
-            if final_out.shape[-1] != combine_fwd_out.shape[-1]:
-                final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out  # broadcast and add directly
-            else:
-                final_out += combine_fwd_out
-            inputs = final_out
+            # TODO: check correct
+            # if final_out.shape[-1] != combine_fwd_out.shape[-1]:
+            #     final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out  # broadcast and add directly
+            # else:
+            #     final_out += combine_fwd_out
+            inputs = final_out + combine_fwd_out
 
             final_out._record_stream()
             combine_fwd_out._record_stream()
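
The new path adds combine_fwd_out out of place instead of accumulating into final_out, which implicitly assumes the two tensors now share the same trailing dimension; the old shape-mismatch branch survives only as a commented-out TODO. If that safety net were wanted back, a guard along these lines would do (illustrative only, not part of the commit):

    # Hypothetical check mirroring the removed branch; the commit relies on
    # final_out and combine_fwd_out having the same hidden width.
    assert final_out.shape[-1] == combine_fwd_out.shape[-1], (
        f"hidden width mismatch: {final_out.shape[-1]} vs {combine_fwd_out.shape[-1]}"
    )
    inputs = final_out + combine_fwd_out
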
@@ -1400,9 +1401,7 @@ def build_overlapped_nodes(forward_chunk, backward_chunk):
     backward_pre_node = ScheduleChunk(list(reversed(backward_pre_overlap_layers)))
     backward_post_node = ScheduleChunk(list(reversed(backward_post_overlap_layers)))
 
-    if not forward_chunk.nodes and all(
-        isinstance(n, FusionFp8DecoderLayerNode) for n in backward_chunk.nodes
-    ):
+    if not forward_chunk.nodes and all(isinstance(n, FusionFp8DecoderLayerNode) for n in backward_chunk.nodes):
         backward_post_node = DecoderBackwardScheduleChunk(backward_post_overlap_layers)
 
     overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM)
@@ -1938,7 +1937,8 @@ def attn_compute_for_fusion(self, args):
         hidden_states = self.hnorm(hidden_states)
         nextn_hidden_state = self.enorm(nextn_hidden_state)
 
-        hidden_states = self.eh_proj(paddle.concat([nextn_hidden_state, hidden_states], axis=-1))
+        concat_h = paddle.concat([nextn_hidden_state, hidden_states], axis=-1)
+        hidden_states = FP8LinearFunction.apply(concat_h, self.eh_proj)
 
         # attention compute
         hidden_states, residual = self.self_attn_compute(hidden_states)
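
Both MTP paths (modeling.py above and this one) now push the concatenated [nextn_hidden_state, hidden_states] through FP8LinearFunction.apply(x, self.eh_proj) instead of calling the layer or LMHeadFunction directly, so eh_proj takes the same blockwise-FP8 deep_gemm route as the other FP8 linears. A bf16 stand-in for what such a PyLayer looks like (the real one lives in fp8_utils.py, takes the Linear layer itself, and quantizes both operands; this is only the autograd skeleton):

    import paddle

    class SketchLinearFunction(paddle.autograd.PyLayer):
        """Autograd skeleton only; FP8LinearFunction additionally quantizes x and w
        to FP8 (1x128 blocks) and dispatches the matmuls to deep_gemm."""

        @staticmethod
        def forward(ctx, x, w):
            ctx.save_for_backward(x, w)
            return paddle.matmul(x, w)                     # [tokens, in] @ [in, out]

        @staticmethod
        def backward(ctx, dout):
            x, w = ctx.saved_tensor()
            dx = paddle.matmul(dout, w, transpose_y=True)  # gradient w.r.t. the input
            dw = paddle.matmul(x, dout, transpose_x=True)  # gradient w.r.t. the weight
            return dx, dw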

paddlenlp/transformers/fp8_utils.py

Lines changed: 15 additions & 9 deletions
@@ -50,6 +50,10 @@ def swiglu(x, y=None):
 ]
 
 
+def get_sm_num():
+    return 112
+
+
 def set_parameter_color(
     parameters, color, group=None, offline_quant_expert_weight=True, clear_origin_weight_when_offline_quant=True
 ):
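
Every DeepGEMM call site below now reads its SM budget from get_sm_num() instead of hard-coding 118, and the helper lowers the budget to 112, presumably to leave headroom for overlapped communication kernels. If one wanted to derive the number from the device instead of hard-coding it (an alternative sketch, not what the commit does), something like this would work:

    import paddle

    def get_sm_num_from_device(reserved_sms: int = 20) -> int:
        """Hypothetical variant: give DeepGEMM everything except a few SMs
        reserved for concurrent communication kernels."""
        props = paddle.device.cuda.get_device_properties()
        return max(props.multi_processor_count - reserved_sms, 1)
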
@@ -159,7 +163,7 @@ def padding_and_quant_input(tensor):
     tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
         tensor,
         output_scale_transpose=True,
-        tquant_method="1x128",
+        quant_method="1x128",
         input_transpose=True,
         return_transpose_only=True,
     )
@@ -178,7 +182,7 @@ def kitchen_gemm(
         if out is None:
             out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype)
         if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
-            deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=118)
+            deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=get_sm_num())
         return out
 
     if out is not None:
@@ -261,7 +265,9 @@ def compute_fp8_linear(
     if out is None:
         out = paddle.empty([input_fp8.shape[0], weight_fp8.shape[0]], dtype=weight.dtype)
 
-    deep_gemm.gemm_fp8_fp8_bf16_nt((input_fp8, input_scale.T), (weight_fp8, weight_scale), out, num_sms=118)
+    deep_gemm.gemm_fp8_fp8_bf16_nt(
+        (input_fp8, input_scale.T), (weight_fp8, weight_scale), out, num_sms=get_sm_num()
+    )
 
     # Return outputs
     if return_mode == "output_only":
@@ -351,7 +357,7 @@ def common_fp8_mlp_bwd(
     # Recompute o1 using deep_gemm(x_fp8, w1_t_fp8)
     w1_fp8, w1_scale = weight_quant(w1, True)
     o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=118)
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=get_sm_num())
 
     # ===== [recompute] o2 = swiglu(o1) =====
     o2 = swiglu(o1)
@@ -838,7 +844,7 @@ def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out
             (x_fp8[start_idx:end_idx], x_scale_tma_align),
             (w_fp8[i], w_scale[i]),
             gemm_out[start_idx:end_idx],
-            num_sms=118,
+            num_sms=get_sm_num(),
         )
 
         start_idx = end_idx
@@ -927,7 +933,7 @@ def fwd_gate_up(self, x, expert_w1, num_expert, tokens_per_expert, m_indices=Non
            (w1_t_quant, w1_t_scale),
            o1,
            m_indices=self.m_indices if m_indices is None else m_indices,
-            num_sms=118,
+            num_sms=get_sm_num(),
        )
 
        if m_indices is None:
@@ -981,7 +987,7 @@ def fwd_down(
            (w2_quant, w2_scale),
            o3,
            m_indices=m_indices if self.fwd_subbatch else self.m_indices,
-            num_sms=118,
+            num_sms=get_sm_num(),
        )
 
        return o3
@@ -1022,7 +1028,7 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, tokens_per_expert, m_indi
            (bw_w2_quant, bw_w2_scale),
            do2_s,
            m_indices=m_indices if self.bwd_subbatch else self.m_indices,
-            num_sms=118,
+            num_sms=get_sm_num(),
        )
 
        with paddle.amp.auto_cast(False):
@@ -1068,7 +1074,7 @@ def bwd_gate_up_input(self, do1, expert_w1, tokens_per_expert, m_indices=None, d
            (bw_w1_quant, bw_w1_scale),
            dx,
            m_indices=m_indices if self.bwd_subbatch else self.m_indices,
-            num_sms=118,
+            num_sms=get_sm_num(),
        )
 
        return dx
