@@ -464,18 +464,91 @@ def forward_func(mod):
        print(
            f"LoRA forward pass successful! Output shape: {output_lora_quant.shape if hasattr(output_lora_quant, 'shape') else 'N/A'}"
        )
+        print(model)
        print("Test passed!")
    finally:
        # Clean up model parallel groups
        parallel_state.destroy_model_parallel()


def _test_quantize_then_lora_save_restore():
-    pass
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234)
+
+    try:
+        model_ref = _gpt_model_provider(tp_size=1)
+        model_test = _gpt_model_provider(tp_size=1)
+        prompt_tokens = torch.randint(
+            0, model_test.vocab_size, (2, model_test.max_sequence_length)
+        ).cuda()
+
+        def forward_func(mod):
+            output = megatron_prefill(model_ref, prompt_tokens)
+
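+        # Quantize the reference model to FP8 (calibrated via the prefill forward pass), then attach LoRA adapters on top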
+        mtq.quantize(model_ref, mtq.FP8_DEFAULT_CFG, forward_func)
+        lora_config = {
+            "adapter_type": "lora",
+            "adapter_name": "default",
+            "adapter_cfg": {
+                "*attention*": {"rank": 32, "scale": 1},
+                "*mlp*": {"rank": 64, "scale": 1},
+            },
+        }
+        model_ref = mtp.update_model(model_ref, lora_config)
+        tmp_path = "./model_ref"
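+        # Save the distributed checkpoint and the sharded modelopt state, then restore both into the fresh model_test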
+        save_distributed_checkpoint(tmp_path, model_ref)
+        save_sharded_modelopt_state([model_ref], tmp_path)
+        restore_sharded_modelopt_state([model_test], tmp_path)
+        model_test = load_distributed_checkpoint(tmp_path, model_test)
+        # Run forward pass on both models so the outputs can be compared
+        output_test = megatron_prefill(model_test, prompt_tokens)
+        output_ref = megatron_prefill(model_ref, prompt_tokens)
+        print(
+            f"Forward pass successful! Output shape: {output_test.shape if hasattr(output_test, 'shape') else 'N/A'}"
+        )
+        print(model_ref)
+        print(f"output_test: {output_test}")
+        print(f"output_ref: {output_ref}")
+
+    finally:
+        # Clean up model parallel groups
+        parallel_state.destroy_model_parallel()


def _test_lora_then_quantize():
-    pass
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234)
+
+    try:
+        model = _gpt_model_provider(tp_size=1)
+        prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+        lora_config = {
+            "adapter_type": "lora",
+            "adapter_name": "default",
+            "adapter_cfg": {
+                "*attention*": {"rank": 32, "scale": 1},
+                "*mlp*": {"rank": 64, "scale": 1},
+            },
+        }
+
+        def forward_func(mod):
+            output = megatron_prefill(model, prompt_tokens)
+
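+        # Reverse order from the tests above: attach LoRA adapters first, then quantize the adapted model to FP8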
+        model = mtp.update_model(model, lora_config)
+        mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_func)
+        lora_count = 0
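+        # Count the modules that received LoRA adapters to confirm the conversion was applied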
+        for name, module in model.named_modules():
+            if hasattr(module, "_lora_adapters"):
+                lora_count += 1
+                print(f"LoRA module found: {name}")
+        print(f"\nTotal LoRA modules: {lora_count}")
+        output_lora_quant = megatron_prefill(model, prompt_tokens)
+        print(
+            f"LoRA forward pass successful! Output shape: {output_lora_quant.shape if hasattr(output_lora_quant, 'shape') else 'N/A'}"
+        )
+        print("Test passed!")
+        print(model)
+    finally:
+        # Clean up model parallel groups
+        parallel_state.destroy_model_parallel()


def _test_lora_then_quantize_save_restore():
@@ -520,7 +593,9 @@ def main():
        # _test_lora_save_and_restore()
        # _test_lora_add_2nd_lora()
        # _test_lora_save_and_restore_with2loras()
-        _test_quantize_then_lora()
+        # _test_quantize_then_lora()
+        # _test_quantize_then_lora_save_restore()
+        _test_lora_then_quantize()
    finally:
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()