@@ -464,18 +464,91 @@ def forward_func(mod):
         print(
             f"LoRA forward pass successful! Output shape: {output_lora_quant.shape if hasattr(output_lora_quant, 'shape') else 'N/A'}"
         )
+        print(model)
         print("Test passed!")
     finally:
         # Clean up model parallel groups
         parallel_state.destroy_model_parallel()


 def _test_quantize_then_lora_save_restore():
-    pass
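+    # Exercises the quantize -> LoRA -> save -> restore round trip: quantize a
+    # reference model, attach adapters, checkpoint it, then rebuild the same
+    # state in a fresh model and compare prefill outputs.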
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234)
+
+    try:
+        model_ref = _gpt_model_provider(tp_size=1)
+        model_test = _gpt_model_provider(tp_size=1)
+        prompt_tokens = torch.randint(
+            0, model_test.vocab_size, (2, model_test.max_sequence_length)
+        ).cuda()
+
+        def forward_func(mod):
+            megatron_prefill(model_ref, prompt_tokens)
+
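+        # Calibration pass: mtq.quantize invokes forward_func so the FP8
+        # quantizers can observe activation ranges.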
+        mtq.quantize(model_ref, mtq.FP8_DEFAULT_CFG, forward_func)
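+        # adapter_cfg keys are wildcard patterns matched against module names;
+        # attention and MLP layers get different LoRA ranks here.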
+        lora_config = {
+            "adapter_type": "lora",
+            "adapter_name": "default",
+            "adapter_cfg": {
+                "*attention*": {"rank": 32, "scale": 1},
+                "*mlp*": {"rank": 64, "scale": 1},
+            },
+        }
+        model_ref = mtp.update_model(model_ref, lora_config)
+        tmp_path = "./model_ref"
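+        # Persist both the sharded weights and the modelopt state (quantizer and
+        # LoRA metadata); restore the modelopt state into the fresh model before
+        # loading weights so the wrapped module structure matches.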
+        save_distributed_checkpoint(tmp_path, model_ref)
+        save_sharded_modelopt_state([model_ref], tmp_path)
+        restore_sharded_modelopt_state([model_test], tmp_path)
+        model_test = load_distributed_checkpoint(tmp_path, model_test)
+        # Compare prefill outputs of the restored and reference models
+        output_test = megatron_prefill(model_test, prompt_tokens)
+        output_ref = megatron_prefill(model_ref, prompt_tokens)
+        print(
+            f"Forward pass successful! Output shape: {output_test.shape if hasattr(output_test, 'shape') else 'N/A'}"
+        )
+        print(model_ref)
+        print(f"output_test: {output_test}")
+        print(f"output_ref: {output_ref}")
+
+    finally:
+        # Clean up model parallel groups
+        parallel_state.destroy_model_parallel()


 def _test_lora_then_quantize():
-    pass
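+    # Reverse order of the test above: attach LoRA adapters first, then
+    # quantize, so calibration runs with the adapters already in place.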
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234)
+
+    try:
+        model = _gpt_model_provider(tp_size=1)
+        prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+        lora_config = {
+            "adapter_type": "lora",
+            "adapter_name": "default",
+            "adapter_cfg": {
+                "*attention*": {"rank": 32, "scale": 1},
+                "*mlp*": {"rank": 64, "scale": 1},
+            },
+        }
+
+        def forward_func(mod):
+            megatron_prefill(model, prompt_tokens)
+
+        model = mtp.update_model(model, lora_config)
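+        # Quantize after LoRA: FP8 calibration sees the adapter outputs too.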
+        mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_func)
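+        # Count the modules that carry a _lora_adapters attribute, i.e. the
+        # layers that actually received adapters.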
+        lora_count = 0
+        for name, module in model.named_modules():
+            if hasattr(module, "_lora_adapters"):
+                lora_count += 1
+                print(f"LoRA module found: {name}")
+        print(f"\nTotal LoRA modules: {lora_count}")
+        output_lora_quant = megatron_prefill(model, prompt_tokens)
+        print(
+            f"LoRA forward pass successful! Output shape: {output_lora_quant.shape if hasattr(output_lora_quant, 'shape') else 'N/A'}"
+        )
+        print("Test passed!")
+        print(model)
+    finally:
+        # Clean up model parallel groups
+        parallel_state.destroy_model_parallel()


 def _test_lora_then_quantize_save_restore():
@@ -520,7 +593,9 @@ def main():
         # _test_lora_save_and_restore()
         # _test_lora_add_2nd_lora()
         # _test_lora_save_and_restore_with2loras()
-        _test_quantize_then_lora()
+        # _test_quantize_then_lora()
+        # _test_quantize_then_lora_save_restore()
+        _test_lora_then_quantize()
     finally:
         if torch.distributed.is_initialized():
             torch.distributed.destroy_process_group()