Merged

Changes from 4 commits
23 changes: 23 additions & 0 deletions examples/run_finetune.py
@@ -49,6 +49,19 @@
os.environ["USE_CASUAL_MASK"] = "False"


def mock_offload_optimizer():
Collaborator:
unified_checkpoint_config: ignore_merge_optimizer
optim: adamw_custom
tensorwise_offload_optimizer: True
Adding the settings above to the training config also offloads the optimizer and reduces GPU memory usage. For now, please drop the optimizer-related changes from this PR; it can be merged first as an initial version.
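For reference, the reviewer's suggestion maps onto plain training arguments in the JSON config that run_finetune.py consumes. A minimal sketch, assuming the standard SFT config schema; the model path, output directory, and file name are placeholders, not values from this PR:

```python
# Hypothetical config sketch based on the reviewer's suggestion above.
import json

sft_args = {
    "model_name_or_path": "your/model/path",   # placeholder
    "output_dir": "./checkpoints/sft",         # placeholder
    "unified_checkpoint_config": "ignore_merge_optimizer",
    "optim": "adamw_custom",
    "tensorwise_offload_optimizer": True,
}

with open("sft_offload.json", "w") as f:
    json.dump(sft_args, f, indent=2)

# Then: python examples/run_finetune.py sft_offload.json
```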

Contributor Author:
Removed.

"""
mock offload optimizer
"""
try:
from paddleformers.trainer.utils.offload_optimizer import hack_offload_optimizer

hack_offload_optimizer()
logger.warning("hack_offload_optimizer called.")
except ImportError:
logger.warning("hack_offload_optimizer is not imported")


def main():
    parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig))
    if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
@@ -60,9 +73,18 @@ def main():
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if training_args.tensorwise_offload_optimizer:
        mock_offload_optimizer()

    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    if training_args.pre_alloc_memory > 0:
        # Reserve a large block up front; freeing it keeps the memory in the
        # allocator's cache so later allocations reuse it.
        memory_size = int(training_args.pre_alloc_memory * 1024 * 1024 * 1024)
        x = paddle.empty([memory_size], dtype=paddle.uint8)
        logger.info(f"pre_alloc_memory size {x.shape}")
        del x

    # Setup GPU & distributed training
    paddle.set_device(training_args.device)
    set_seed(seed=training_args.seed)
@@ -134,6 +156,7 @@ def main():
    model_config.max_sequence_length = training_args.max_seq_len
    model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
    model_config._attn_implementation = model_args.attn_impl
    model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
    logger.info(f"Final model config: {model_config}")
    logger.info("Creating model")

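The pre_alloc_memory change above allocates one large uint8 tensor and immediately frees it, so the caching allocator keeps the block in its pool for later reuse. A standalone sketch of the same trick, assuming Paddle's default caching GPU allocator; the 2 GB figure is an arbitrary example, not a value from the PR:

```python
# Standalone sketch of the pre-allocation trick used in run_finetune.py above.
import paddle

def pre_alloc(gb: float) -> None:
    n_bytes = int(gb * 1024 * 1024 * 1024)
    # Allocate one large uint8 tensor, then drop the reference. The allocator
    # keeps the freed block in its pool, so later allocations reuse it instead
    # of repeatedly growing (and fragmenting) device memory.
    buf = paddle.empty([n_bytes], dtype=paddle.uint8)
    del buf

if __name__ == "__main__":
    paddle.set_device("gpu")
    pre_alloc(2.0)  # reserve ~2 GB before building the model and optimizer
```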
2 changes: 1 addition & 1 deletion paddleformers/trainer/trainer.py
@@ -2788,7 +2788,7 @@ def _save_checkpoint(self, model, metrics=None):
        optimizer_name = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
        saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")

        if self.args.unified_checkpoint and self.args.offload_optim:
        if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer):
Collaborator:
This causes GPU memory to spike abnormally while the optimizer parameters are being saved; this approach is not recommended.

Contributor Author:
done

            self._reload_optimizer()

        if self.args.use_hybrid_parallel:
4 changes: 4 additions & 0 deletions paddleformers/trainer/training_args.py
@@ -1096,6 +1096,10 @@ class TrainingArguments:
        default=False,
        metadata={"help": "Controls the parallel execution order. False (pp first), True (sharding first)."},
    )
    pre_alloc_memory: int = field(
        default=0,
        metadata={"help": "Size of device memory to pre-allocate before training, in GB."},
    )

    def __post_init__(self):
        world_size = paddle.distributed.get_world_size()
81 changes: 81 additions & 0 deletions paddleformers/trainer/utils/offload_optimizer.py
@@ -0,0 +1,81 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import _C_ops
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
HybridParallelOptimizer,
)
from paddle.optimizer import Optimizer

from .sharding_io import to_device


def offload(tensor):
    if paddle.is_compiled_with_cuda():
        place = paddle.CUDAPinnedPlace()
    else:
        place = paddle.CPUPlace()

    new_tensor = to_device(tensor, place)
    assert new_tensor is tensor, "to_device must be inplace operation"


def reload(tensor):
    new_tensor = to_device(tensor)
    assert new_tensor is tensor, "to_device must be inplace operation"


def hack_offload_optimizer():
    # Step 1: mock _add_accumulator so every newly created optimizer state tensor
    # is offloaded to pinned/CPU memory right after creation.
    origin_add_accumulator = getattr(Optimizer, "_add_accumulator")

    def new_add_accumulator(self, *args, **kwargs):
        x = origin_add_accumulator(self, *args, **kwargs)
        offload(x)
        return x

    setattr(Optimizer, "_add_accumulator", new_add_accumulator)

    # Step 2: mock _C_ops.adam_ and _C_ops.adamw_ so optimizer states are reloaded
    # to the device only for the kernel call and offloaded again afterwards.
    for name in ["adam_", "adamw_"]:
        origin_op = getattr(_C_ops, name)

        def new_opt_op(*args, origin_op=origin_op):  # bind this iteration's original op
            for arg in args:
                if isinstance(arg, paddle.Tensor):
                    reload(arg)

            ret = origin_op(*args)

            for i, arg in enumerate(args):
                if i >= 2 and isinstance(arg, paddle.Tensor):  # do not offload parameter and gradient
                    offload(arg)
            return ret

        setattr(_C_ops, name, new_opt_op)

    # Step 3: mock _insert_sync so offloaded tensors are brought back to the device
    # for the sync and restored to their original place afterwards.
    opt_type = HybridParallelOptimizer
    origin_insert_sync = getattr(opt_type, "_insert_sync")

    def new_insert_sync(self, sync_var, *args, **kwargs):
        origin_place = sync_var.place
        reload(sync_var)
        ret = origin_insert_sync(self, sync_var, *args, **kwargs)
        new_sync_var = to_device(sync_var, origin_place)
        assert new_sync_var is sync_var, "to_device must be inplace operation"
        return ret

    setattr(opt_type, "_insert_sync", new_insert_sync)
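A minimal usage sketch of the new module, assuming a CUDA build of Paddle; the linear model and AdamW setup below are illustrative placeholders, not part of the PR (in the PR the patch is installed by run_finetune.py when tensorwise_offload_optimizer is set):

```python
import paddle
from paddleformers.trainer.utils.offload_optimizer import hack_offload_optimizer

# Install the patches before any optimizer state is created (i.e. before the
# first step), so _add_accumulator offloads each new state tensor as it appears.
hack_offload_optimizer()

model = paddle.nn.Linear(1024, 1024)                      # placeholder model
opt = paddle.optimizer.AdamW(learning_rate=1e-4, parameters=model.parameters())

x = paddle.randn([8, 1024])
loss = model(x).mean()
loss.backward()

# During step(), the patched _C_ops.adamw_ reloads each tensor argument to the
# device, runs the kernel, then offloads everything except the parameter and
# gradient (argument indices 0 and 1).
opt.step()
opt.clear_grad()
```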
2 changes: 2 additions & 0 deletions paddleformers/transformers/glm4_moe/configuration.py
@@ -158,6 +158,7 @@ def __init__(
        seq_aux=True,
        topk_method="noaux_tc",
        using_flex_token=True,
        moe_subbatch_token_num=0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
@@ -200,6 +201,7 @@ def __init__(
        self.topk_method = topk_method
        self.using_flex_token = using_flex_token
        self.use_fp8 = False
        self.moe_subbatch_token_num = moe_subbatch_token_num

        self.pp_seg_method = pp_seg_method
        self.disable_ffn_model_parallel = disable_ffn_model_parallel