Commit 672ee98

[Unified Checkpoint] fix checkpoint names (#7794)
When skipping the model weights save and saving the master weights as model weights, unified checkpoint needs to choose the model weights file to load into the master weights on resumption.
1 parent bb9062e commit 672ee98
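
For orientation, here is a minimal sketch (not part of the commit) of the naming decision the patch makes on the save side: when SKIP_SAVE_MODEL_WEIGHT is enabled, the master-weights index is written under the model-weights index name, so a later resume can pick it up as model weights. The constant strings, the option value, and the helper name below are illustrative assumptions, not PaddleNLP's actual definitions; the real logic lives in save_unified_optimizer / unified_optimizer_into_shards in the diff that follows.

# Illustrative sketch only. The constant strings, the option value, and the
# helper name are assumptions for demonstration, not PaddleNLP's definitions.
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
PADDLE_WEIGHTS_INDEX_NAME = "model_state.pdparams.index.json"
SAFE_MASTER_WEIGHTS_INDEX_NAME = "master_weights.safetensors.index.json"
PADDLE_MASTER_WEIGHTS_INDEX_NAME = "master_weights.pdparams.index.json"
SKIP_SAVE_MODEL_WEIGHT = "skip_save_model_weight"  # assumed option string


def pick_master_weights_index_name(unified_checkpoint_config, safe_serialization):
    """Hypothetical mirror of the save-side file-name choice this commit introduces."""
    # Default: master weights get their own index file.
    name = SAFE_MASTER_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_INDEX_NAME
    # With skip_save_model_weight, the master weights stand in for the model
    # weights, so they are written under the model-weights index name instead.
    if SKIP_SAVE_MODEL_WEIGHT in unified_checkpoint_config:
        name = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME
    return name


print(pick_master_weights_index_name(["skip_save_model_weight"], safe_serialization=True))
# prints: model.safetensors.index.json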

1 file changed: +51 −35 lines changed

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 51 additions & 35 deletions
@@ -115,44 +115,47 @@ def save_unified_checkpoint(args, model, optimizer, output_dir, safe_serializati
     else:
         raise ValueError("Unified checkpoint only supports PretrainedModel")
 
+    skip_save_model_weight = False
     if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
         if is_need_master_weight(optimizer, is_fp16_or_bp16=(args.fp16 or args.bf16)):
             logger.info(
                 f"With {UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value}, skip the model checkpoint save."
                 "The master weight will be loaded as model weights for next resumption."
             )
             # not save model weight, load from master weight
-            return
-    config_to_save = None
-    state_dict, config_to_save, shard_file, sharded_index = unified_checkpoint_into_shards(
-        args, model_to_save, safe_serialization=safe_serialization
-    )
+            skip_save_model_weight = True
 
     save_directory = output_dir
     os.makedirs(save_directory, exist_ok=True)
 
-    is_sync_save = True
-    if "async_save" in args.unified_checkpoint_config:
-        is_sync_save = False
-    file_save_async_or_sync(
-        state_dict, os.path.join(save_directory, shard_file), safe_serialization, is_sync=is_sync_save
-    )
+    # save model weights
+    if not skip_save_model_weight:
+        state_dict, shard_file, sharded_index = unified_checkpoint_into_shards(
+            args, model_to_save, safe_serialization=safe_serialization
+        )
+        is_sync_save = True
+        if "async_save" in args.unified_checkpoint_config:
+            is_sync_save = False
+        file_save_async_or_sync(
+            state_dict, os.path.join(save_directory, shard_file), safe_serialization, is_sync=is_sync_save
+        )
+
+        if sharded_index is not None:
+            if not safe_serialization:
+                path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME)
+            else:
+                path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
+
+            with open(path, "w") as f:
+                json.dump(sharded_index, f, indent=4)
 
+    # save the config
+    config_to_save = save_config(model_to_save)
     # Attach architecture to the config
     config_to_save.architectures = [model_to_save.__class__.__name__]
-    # Save the config
     if args.should_save:
         config_to_save.save_pretrained(save_directory)
 
-    if sharded_index is not None:
-        if not safe_serialization:
-            path = os.path.join(output_dir, PADDLE_WEIGHTS_INDEX_NAME)
-        else:
-            path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
-
-        with open(path, "w") as f:
-            json.dump(sharded_index, f, indent=4)
-
 
 def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str, safe_serialization=False) -> None:
     """Load potential model checkpoint
@@ -252,6 +255,18 @@ def _remove_unused_keys(
         raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
 
 
+def save_config(model_to_save):
+    dtype = get_parameter_dtype(model_to_save)
+    model_to_save.config.dtype = str(dtype).split(".")[1]
+    config_to_save = copy.deepcopy(model_to_save.config)
+
+    if config_to_save.tensor_parallel_degree > 1:
+        # do we need to change?
+        config_to_save.tensor_parallel_degree = 1
+
+    return config_to_save
+
+
 def unified_checkpoint_into_shards(
     args,
     model_to_save,
@@ -272,8 +287,6 @@ def unified_checkpoint_into_shards(
 
     all_filter_keys = filter_params(model_to_save, state_dict)
 
-    dtype = get_parameter_dtype(model_to_save)
-    model_to_save.config.dtype = str(dtype).split(".")[1]
     config_to_save = copy.deepcopy(model_to_save.config)
 
     if config_to_save.tensor_parallel_degree > 1:
@@ -282,10 +295,6 @@ def unified_checkpoint_into_shards(
         )
         state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)
 
-    if config_to_save.tensor_parallel_degree > 1:
-        # do we need to change?
-        config_to_save.tensor_parallel_degree = 1
-
     # build index json file
     index_weight_file = {}
     total_size = 0
@@ -302,7 +311,7 @@ def unified_checkpoint_into_shards(
         total_size_list,
     )
 
-    return state_dict, config_to_save, shard_file, sharded_index
+    return state_dict, shard_file, sharded_index
 
 
 def save_unified_optimizer(args, model, optimizer, output_dir, safe_serialization=False):
@@ -343,16 +352,17 @@ def save_unified_optimizer(args, model, optimizer, output_dir, safe_serializatio
     )
 
     if sharded_optim_index is not None:
-        if not safe_serialization:
-            path = os.path.join(output_dir, PADDLE_OPTIMIZER_INDEX_NAME)
-            master_path = os.path.join(output_dir, PADDLE_MASTER_WEIGHTS_INDEX_NAME)
-        else:
-            path = os.path.join(output_dir, SAFE_OPTIMIZER_INDEX_NAME)
-            master_path = os.path.join(output_dir, SAFE_MASTER_WEIGHTS_INDEX_NAME)
-
+        optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME if safe_serialization else PADDLE_OPTIMIZER_INDEX_NAME
+        path = os.path.join(output_dir, optimizer_index_name)
         with open(path, "w") as f:
             json.dump(sharded_optim_index, f, indent=4)
 
+        master_weights_name = (
+            SAFE_MASTER_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_INDEX_NAME
+        )
+        if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
+            master_weights_name = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else PADDLE_WEIGHTS_INDEX_NAME
+        master_path = os.path.join(output_dir, master_weights_name)
         if master_weight_state_dict is not None:
             with open(master_path, "w") as f:
                 json.dump(sharded_master_weight_index, f, indent=4)
@@ -561,6 +571,8 @@ def unified_optimizer_into_shards(
     total_optim_size, total_master_weight_size = 0, 0
     optimizer_name = SAFE_OPTIMIZER_NAME if safe_serialization else PADDLE_OPTIMIZER_NAME
     master_weights_name = SAFE_MASTER_WEIGHTS_NAME if safe_serialization else PADDLE_MASTER_WEIGHTS_NAME
+    if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
+        master_weights_name = SAFE_WEIGHTS_NAME if safe_serialization else PADDLE_WEIGHTS_NAME
     shard_optimizer_file = get_sharded_file_name(args, optimizer_name, is_optimizer=True)
     shard_master_weight_file = get_sharded_file_name(args, master_weights_name, is_optimizer=True)
 
@@ -1648,6 +1660,10 @@ def update_master_weight_status(args, optimizer, has_master_weight, safe_seriali
             index_filename_master_weights = (
                 PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME
             )
+            if UnifiedCheckpointOption.SKIP_SAVE_MODEL_WEIGHT.value in args.unified_checkpoint_config:
+                index_filename_master_weights = (
+                    PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
+                )
     else:
         has_master_weight = False
         index_filename_master_weights = None
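
The last hunk above is the load-side counterpart of that naming change: on resumption, update_master_weight_status must look for the master weights behind the model-weights index, because that is where they were written. A complementary sketch, meant to be appended to the sketch near the top of this page and reusing its assumed constants (again illustrative, not the library's API):

def pick_master_weights_load_index(unified_checkpoint_config, safe_serialization):
    """Hypothetical mirror of the index-file choice in update_master_weight_status."""
    index_filename = (
        PADDLE_MASTER_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_MASTER_WEIGHTS_INDEX_NAME
    )
    # A checkpoint written with skip_save_model_weight stores its master weights
    # in the model-weights files, so resumption must read that index instead.
    if SKIP_SAVE_MODEL_WEIGHT in unified_checkpoint_config:
        index_filename = PADDLE_WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
    return index_filename

As long as the save-side and load-side choices agree on these names, a run that skipped the model-weight save can still resume cleanly, which is what the commit message describes.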
