Commit 942865f

[Trainer] Support release gradients for develop branch (#7594)

* support release gradients for develop branch

1 parent b541fc2 · commit 942865f

File tree

  paddlenlp/trainer/trainer.py
  paddlenlp/trainer/training_args.py

2 files changed: +12 -1 lines changed

paddlenlp/trainer/trainer.py

Lines changed: 9 additions & 1 deletion
@@ -930,6 +930,7 @@ def train(
                 )
                 enable_delay_scale_loss = "enable_delay_scale_loss" in pipeline_parallel_config
                 enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
+                enable_release_grads = "enable_release_grads" in pipeline_parallel_config

                 # Case 3: Pipeline parallel mode, overlap with dp
                 if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
@@ -982,7 +983,14 @@ def train(
                     if optimizer_was_run:
                         self.lr_scheduler.step()

-                    self.optimizer.clear_grad()
+                    if enable_release_grads and args.pipeline_parallel_degree > 1:
+                        self.optimizer.clear_grad(set_to_zero=False)
+                        for _, buffers in model._chunk_2_comm_buffers.items():
+                            for buffer in buffers:
+                                buffer._clear_grad_storage()
+                    else:
+                        self.optimizer.clear_grad()
+
                     self.callback_handler.on_optimizer_end(
                         args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
                     )
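The change distinguishes between zeroing gradients and releasing their storage. Below is a minimal, standalone sketch of that distinction using Paddle's optimizer API, assuming a Paddle version whose Optimizer.clear_grad accepts the set_to_zero argument used in the patch; the model and optimizer here are illustrative placeholders, not the Trainer's own objects.

    # Minimal sketch: zeroing vs. releasing gradients after an optimizer step.
    import paddle

    model = paddle.nn.Linear(8, 8)
    opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())

    loss = model(paddle.randn([4, 8])).mean()
    loss.backward()
    opt.step()

    # Old path (`self.optimizer.clear_grad()`): keep the gradient tensors and
    # reset them to zero, so their memory stays allocated between iterations.
    # opt.clear_grad()

    # New path when enable_release_grads is set: drop the gradient storage; the
    # next backward pass re-creates it, which lowers peak memory between steps.
    opt.clear_grad(set_to_zero=False)

In the pipeline-parallel branch the Trainer additionally calls buffer._clear_grad_storage() on every buffer in model._chunk_2_comm_buffers, so the fused communication buffers release their gradient storage as well.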

paddlenlp/trainer/training_args.py

Lines changed: 3 additions & 0 deletions
@@ -243,6 +243,7 @@ class TrainingArguments:
             enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.
             enable_dp_comm_overlap, fuse data parallel gradient communication.
             enable_sharding_comm_overlap, fuse sharding stage 1 parallel gradient communication.
+            enable_release_grads, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration.
         sharding_parallel_config (`str`, *optional*)(
             Some additional config it highly affect the useage of sharding parallel, we provide some option to config it.
             following config is support:
@@ -930,6 +931,7 @@ def __post_init__(self):
                     "enable_dp_comm_overlap",
                     "enable_sharding_comm_overlap",
                     "enable_timer",
+                    "enable_release_grads",
                 ]:
                     raise ValueError(
                         f"Found unknown pipeline mode config {x}, accpet config is disable_p2p_cache_shape, disable_partial_send_recv."
@@ -950,6 +952,7 @@ def __post_init__(self):
                 "sharding_comm_overlap": "enable_sharding_comm_overlap" in pipeline_parallel_config
                 and self.sharding_parallel_degree > 1,
                 "enable_timer": "enable_timer" in pipeline_parallel_config,
+                "release_gradients": "enable_release_grads" in pipeline_parallel_config,
             }
             if dygraph_pp_configs["dp_comm_overlap"]:
                 raise ValueError("overlap has accuracy issue")  # TODO: fix `overalap` + `delay_scale` issue
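For reference, a hypothetical configuration sketch showing how the new option would be enabled through the existing pipeline_parallel_config string (assumed here to take space-separated option names, as with the other options documented above); output_dir and the parallel degree are placeholders.

    # Hypothetical usage sketch; values are placeholders, not recommendations.
    from paddlenlp.trainer import TrainingArguments

    args = TrainingArguments(
        output_dir="./checkpoints",
        pipeline_parallel_degree=2,  # release-grads path only triggers when > 1
        pipeline_parallel_config="enable_release_grads",
    )

With this in place, __post_init__ maps the flag to the "release_gradients" key of the dygraph pipeline-parallel configs, and the Trainer takes the clear_grad(set_to_zero=False) branch shown in trainer.py above.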
