@@ -367,6 +367,32 @@ def get_learning_rate_scheduler(optimizer):
     return lr_scheduler


+def sync_hp_to_lp(optimizer):
+
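+    # Push the fp32 master ("hp") params held by the DeepSpeed optimizer back into
+    # the bf16/fp16 model ("lp") params, e.g. after a universal checkpoint load has
+    # overwritten the fp32 copies.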
+    optimizer.update_lp_params()
+
+    # for n, p in model.named_parameters():
+    #     print(n)
+
+    #     if p._hp_mapping is not None:
+    #         #print(f'rank {rank} fixing hp for input_layernorm')
+    #         #p._hp_mapping.update_hp()
+
+    #         hp = p._hp_mapping.hp_fragment
+
+    #         torch.distributed.all_reduce(hp, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+
+    #         # 3. optim states
+    #         for key in ['exp_avg', 'exp_avg_sq']:
+    #             optim_state_fragment = p._hp_mapping.get_optim_state_fragment(key)
+    #             #print(f'rank {rank} before reduce optim state fragment {key} = {optim_state_fragment}')
+    #             torch.distributed.all_reduce(optim_state_fragment, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
+    #             #print(f'rank {rank} after reduce optim state fragment {key} = {optim_state_fragment}')
+
+
+
 def setup_model_and_optimizer(model_provider_func):
     """Setup model and optimizer."""
     args = get_args()
@@ -386,12 +412,21 @@ def setup_model_and_optimizer(model_provider_func):

     if args.deepspeed:
         print_rank_0("DeepSpeed is enabled.")
-        pp = mpu.get_pipeline_model_parallel_world_size()
+        #pp = mpu.get_pipeline_model_parallel_world_size()
+
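+        # Read the DeepSpeed config as a dict so it can be patched in place: when
+        # args.universal_checkpoint is set, enable loading of the universal
+        # checkpoint format via the "checkpoint" section.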
+        import json
+        import io
+        with io.open(args.deepspeed_config, "r", encoding="utf-8") as f:
+            config = json.load(f)
+        if args.universal_checkpoint:
+            config["checkpoint"] = {"load_universal": True}
+
         model, optimizer, _, lr_scheduler = deepspeed.initialize(
             model=model[0],
             optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            config=config,
             args=args,
-            lr_scheduler=lr_scheduler
         )

         assert model.fp16_enabled() == args.fp16, "megatron fp16 config does not match deepspeed"
@@ -416,8 +451,37 @@ def setup_model_and_optimizer(model_provider_func):
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])
+
+
+        # hp -> lp: the fp32 (hp) state just loaded from the universal checkpoint is
+        # authoritative, so propagate it back into the bf16/fp16 (lp) model params.
+        if args.deepspeed and args.universal_checkpoint:
+            sync_hp_to_lp(optimizer)
+
+
     else:
         args.iteration = 0
+
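+    # Debug aid: dump model/optimizer weights for this iteration, tagged by whether
+    # a universal checkpoint was loaded, so the two code paths can be compared.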
+    from .utils import dump_weights
+    dump_weights(f'{args.universal_checkpoint=}', args.iteration, model, optimizer)
+
+    # tp_rank = mpu.get_tensor_model_parallel_rank()
+    # pp_rank = mpu.get_pipeline_model_parallel_rank()
+    # dp_rank = mpu.get_data_parallel_rank()
+    # for n, p in model[0].named_parameters():
+    #     if 'word_embeddings.weight' not in n:
+    #         continue
+    #     if tp_rank == 0 and pp_rank == 0:
+    #         print(f"{tp_rank=}{pp_rank=}{dp_rank=} bf16 {n=} {p[:10]=}")
+    #         if p._hp_mapping is not None:
+    #             hp = p._hp_mapping.hp_fragment
+    #             print(f'{tp_rank=}{pp_rank=}{dp_rank=} fp32 {n=} {hp[:10]=}')
+
+    #     if tp_rank == 0 and pp_rank == mpu.get_pipeline_model_parallel_world_size() - 1:
+    #         print(f"{tp_rank=}{pp_rank=}{dp_rank=} bf16 {n=} {p[:10]=}")
+    #         if p._hp_mapping is not None:
+    #             hp = p._hp_mapping.hp_fragment
+    #             print(f'{tp_rank=}{pp_rank=}{dp_rank=} fp32 {n=} {hp[:10]=}')
+

     # We only support local DDP with multiple micro-batches.
     if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1: