@@ -269,7 +269,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
269
269
270
270
malloc_size = 0
271
271
for opt_state_name , opt_state_value in optimizer_state_dict .items ():
272
- malloc_size += opt_state_value .numel () * opt_state_value .element_size ()
272
+ malloc_size += opt_state_value .numel (). numpy () * opt_state_value .element_size ()
273
273
malloc_size = malloc_size / 2 ** 20
274
274
logger .debug (f"{ malloc_size } MB of GPU memory were allocated." )
275
275
@@ -529,6 +529,7 @@ def load_state_dict_and_rename(self):
529
529
rank_access_files [self .cur_rank ] = self .cur_rank_optimizer_state_file_names
530
530
531
531
global_rank_access_files = self .gather_global_object (rank_access_files )
532
+ logger .info (f"The file(s) to be loaded for the global rank are: { global_rank_access_files } " )
532
533
need_read_files = get_rank_to_read_files (global_rank_access_files , global_rank_access_files )
533
534
logger .info (f"The file(s) to be loaded for the current rank are: { need_read_files } " )
534
535
self .cur_rank_loaded_state_dict = {}
@@ -553,8 +554,7 @@ def load_state_dict_and_rename(self):
553
554
memory_size = 0
554
555
for file , state_dict in self .cur_rank_loaded_state_dict .items ():
555
556
for k , v in state_dict .items ():
556
- memory_size += v .numel () * v .element_size ()
557
-
557
+ memory_size += v .numel ().numpy () * v .element_size ()
558
558
memory_size = memory_size / 2 ** 20
559
559
logger .debug (
560
560
f"The current rank has finished loading the checkpoint file and has allocated { memory_size } MB of GPU memory."
0 commit comments