Skip to content

Commit d841351

Browse files
authored
【GLM】subbatch performance and weight bug fix (#2661)
1 parent f22279a commit d841351

File tree

5 files changed

+212
-45
lines changed

5 files changed

+212
-45
lines changed

examples/run_finetune.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ def main():
     training_args.print_config(model_args, "Model")
     training_args.print_config(data_args, "Data")

+    if training_args.pre_alloc_memory > 0:
+        memory_size = int(training_args.pre_alloc_memory * 1024 * 1024 * 1024)
+        x = paddle.empty([memory_size], dtype=paddle.uint8)
+        logger.info(f"pre_alloc_memory size {x.shape}")
+        del x
+
     # Setup GPU & distributed training
     paddle.set_device(training_args.device)
     set_seed(seed=training_args.seed)
@@ -134,6 +140,7 @@ def main():
     model_config.max_sequence_length = training_args.max_seq_len
     model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
     model_config._attn_implementation = model_args.attn_impl
+    model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
     logger.info(f"Final model config: {model_config}")
     logger.info("Creating model")
paddleformers/trainer/training_args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Controls the parallel execution order. False (pp first), True (sharding first)."},
     )
+    pre_alloc_memory: int = field(
+        default=0,
+        metadata={"help": "pre allocate memory size GB"},
+    )

     def __post_init__(self):
         world_size = paddle.distributed.get_world_size()

paddleformers/transformers/glm4_moe/configuration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def __init__(
         seq_aux=True,
         topk_method="noaux_tc",
         using_flex_token=True,
+        moe_subbatch_token_num=0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -200,6 +201,7 @@
         self.topk_method = topk_method
         self.using_flex_token = using_flex_token
         self.use_fp8 = False
+        self.moe_subbatch_token_num = moe_subbatch_token_num

         self.pp_seg_method = pp_seg_method
         self.disable_ffn_model_parallel = disable_ffn_model_parallel

0 commit comments

Comments (0)