Skip to content

Commit c8f1414

Browse files
authored
add config memory_growth_threshold for save hf (#11183)
1 parent e4d28de commit c8f1414

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

paddlenlp/trainer/trainer_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1632,8 +1632,15 @@ def save_hf_checkpoint(
16321632
shard_idx,
16331633
path,
16341634
):
1635+
1636+
memory_growth_threshold = 8 * (2**30)
16351637
itr = model.full(
1636-
aoa_config=aoa_config, h_group=h_group, v_group=v_group, num_splits=num_splits, shard_idx=shard_idx
1638+
aoa_config=aoa_config,
1639+
h_group=h_group,
1640+
v_group=v_group,
1641+
num_splits=num_splits,
1642+
shard_idx=shard_idx,
1643+
memory_growth_threshold=memory_growth_threshold,
16371644
)
16381645
num_saver_ranks = h_group.nranks * v_group.nranks
16391646
rank = h_group.rank + v_group.rank * h_group.nranks

0 commit comments

Comments
 (0)