Skip to content

Commit fc25d32

Browse files
authored
Fix ckpt convert bug1 (#9522)
* refine log
* refine
* fix
* fix
1 parent 741785a commit fc25d32

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

paddlenlp/trainer/utils/ckpt_converter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
269269

270270
malloc_size = 0
271271
for opt_state_name, opt_state_value in optimizer_state_dict.items():
272-
malloc_size += opt_state_value.numel() * opt_state_value.element_size()
272+
malloc_size += opt_state_value.numel().numpy() * opt_state_value.element_size()
273273
malloc_size = malloc_size / 2**20
274274
logger.debug(f"{malloc_size} MB of GPU memory were allocated.")
275275

@@ -529,6 +529,7 @@ def load_state_dict_and_rename(self):
529529
rank_access_files[self.cur_rank] = self.cur_rank_optimizer_state_file_names
530530

531531
global_rank_access_files = self.gather_global_object(rank_access_files)
532+
logger.info(f"The file(s) to be loaded for the global rank are: {global_rank_access_files}")
532533
need_read_files = get_rank_to_read_files(global_rank_access_files, global_rank_access_files)
533534
logger.info(f"The file(s) to be loaded for the current rank are: {need_read_files}")
534535
self.cur_rank_loaded_state_dict = {}
@@ -553,8 +554,7 @@ def load_state_dict_and_rename(self):
553554
memory_size = 0
554555
for file, state_dict in self.cur_rank_loaded_state_dict.items():
555556
for k, v in state_dict.items():
556-
memory_size += v.numel() * v.element_size()
557-
557+
memory_size += v.numel().numpy() * v.element_size()
558558
memory_size = memory_size / 2**20
559559
logger.debug(
560560
f"The current rank has finished loading the checkpoint file and has allocated {memory_size} MB of GPU memory."

scripts/distribute/ci_case_auto.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ function llama_case_list_auto() {
9595
llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2
9696
llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1
9797
llama_pir_auto_fuse_ffn_attention_qkv_MP2
98+
llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1
9899
llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP
99100
llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP
100101
llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1
@@ -1420,6 +1421,7 @@ function llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1() {
14201421
--sharding "" \
14211422
--to_static 0 \
14221423
--num_hidden_layers 2 \
1424+
--unified_checkpoint false \
14231425
>>${log_path}/$FUNCNAME 2>&1
14241426
dy_loss=`cat $dy_case_log_dir/workerlog.0 | grep 'global_step: 4' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
14251427
dy_ips=-1

0 commit comments

Comments (0)