 
 
 import modelopt.torch.peft as mtpf
+import modelopt.torch.quantization as mtq
 from modelopt.torch.peft.lora.layer import LoRAModule
 from modelopt.torch.utils.plugins import megatron_prefill
 
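+# NVFP4-style quantization config: FP4 (2, 1) weights and activations in blocks of 16
+# with dynamic FP8 (4, 3) block scales; all quantizers not matched below stay disabled.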
+NVFP4_DEFAULT_CONFIG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*output_quantizer": {"enable": False},
+        "*output_layer*": {"enable": False},  # Note: only output_layer is disabled.
+        "default": {"enable": False},
+    },
+    "algorithm": "max",
+}
+
 DEFAULT_LORA_CFG_TEST = {
     "adapter_type": "lora",
     "adapter_name": "default",
@@ -333,7 +355,7 @@ def test_attr_changes_with_one_lora(lora_config):
 
 
 def _test_mcore_save_restore(lora_config, tmp_path, rank, size):
-    hidden_size = 1280
+    hidden_size = 512
     initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
     model_ref = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
     model_test = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
@@ -375,11 +397,8 @@ def test_mcore_save_restore(device_count, lora_config, tmp_path):
     )
 
 
-# TODO: Save and restore 2 loras
-
-
 def _test_adapter_gradient_flow_freeze_base_model(lora_config, tmp_path, rank, size):
-    hidden_size = 1280
+    hidden_size = 512
     initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
     model = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
@@ -441,7 +460,7 @@ def test_adapter_gradient_flow_freeze_base_model(device_count, lora_config, tmp_
 
 
 def _test_adapter_gradient_flow_freeze_lora_model(lora_config, tmp_path, rank, size):
-    hidden_size = 1280
+    hidden_size = 512
     lora_config["freeze_lora_weights"] = True
     lora_config["freeze_base_model"] = False
 
@@ -502,7 +521,7 @@ def test_adapter_gradient_flow_freeze_lora_model(device_count, lora_config, tmp_
 
 
 def _test_adapter_gradient_flow(lora_config, tmp_path, rank, size):
-    hidden_size = 1280
+    hidden_size = 512
     lora_config["freeze_lora_weights"] = False
     lora_config["freeze_base_model"] = False
 
@@ -566,3 +585,249 @@ def test_adapter_gradient_flow(device_count, lora_config, tmp_path):
         job=partial(_test_adapter_gradient_flow, lora_config, str(tmp_path)),
         backend="nccl",
     )
+
+
+def _test_quantize_then_lora(lora_config, tmp_path, rank, size):
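+    """Quantize the base model to NVFP4 first, then attach a LoRA adapter.
+
+    The base LoRAModules should carry calibrated quantizers, while the adapter
+    weights themselves should remain unquantized.
+    """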
+    hidden_size = 512
+    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
+    model = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+
+    def forward_func(mod):
+        _ = megatron_prefill(model, prompt_tokens)
+
+    mtq.quantize(model, NVFP4_DEFAULT_CONFIG, forward_func)
+
+    # Then add the LoRA adapter
+    mtpf.update_model(model, lora_config)
+
+    # Skip the output layer, which is excluded from quantization
+    for name, module in model.named_modules():
+        if isinstance(module, LoRAModule) and "output_layer" not in name:
+            assert hasattr(module, "input_quantizer")
+            assert hasattr(module, "weight_quantizer")
+            assert hasattr(module.input_quantizer, "amax")
+            assert hasattr(module.weight_quantizer, "amax")
+            assert getattr(module.input_quantizer, "amax") is not None
+            assert getattr(module.weight_quantizer, "amax") is not None
+            # The LoRA adapters themselves should not carry quantizers.
+            for adapter_name in module._lora_adapters:
+                lora_a = module._lora_adapters[adapter_name]["lora_a"]
+                lora_b = module._lora_adapters[adapter_name]["lora_b"]
+                assert not hasattr(lora_a, "input_quantizer")
+                assert not hasattr(lora_b, "weight_quantizer")
+
+    quantized_lora_output = megatron_prefill(model, prompt_tokens)
+    mtq.disable_quantizer(model, "*")
+    unquantized_lora_output = megatron_prefill(model, prompt_tokens)
+    # Quantized and unquantized runs should produce different outputs.
+    assert not torch.allclose(quantized_lora_output, unquantized_lora_output)
+
+
+@pytest.mark.parametrize("device_count", get_device_counts())
+@pytest.mark.parametrize(
+    "lora_config",
+    [
+        LARGE_LORA_CFG_RANDOM_INIT_TEST,  # Use random init so gradients flow to both lora_a and lora_b
+    ],
+)
+def test_quantize_then_lora(device_count, lora_config, tmp_path):
+    spawn_multiprocess_job(
+        size=device_count,
+        job=partial(_test_quantize_then_lora, lora_config, str(tmp_path)),
+        backend="nccl",
+    )
+
+
+def _test_lora_then_quantize(lora_config, tmp_path, rank, size):
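+    """Attach a LoRA adapter first, then quantize the whole model to NVFP4.
+
+    Both the base LoRAModules and the adapter weights should carry calibrated
+    quantizers, and disabling the adapter quantizers should change the output.
+    """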
+    hidden_size = 512
+    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
+    model = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+
+    mtpf.update_model(model, lora_config)
+    lora_output = megatron_prefill(model, prompt_tokens)
+
+    def forward_func(mod):
+        _ = megatron_prefill(model, prompt_tokens)
+
+    mtq.quantize(model, NVFP4_DEFAULT_CONFIG, forward_func)
+    quantized_output = megatron_prefill(model, prompt_tokens)
+    # Skip the output layer, which is excluded from quantization
+    for name, module in model.named_modules():
+        if isinstance(module, LoRAModule) and "output_layer" not in name:
+            assert hasattr(module, "input_quantizer")
+            assert hasattr(module, "weight_quantizer")
+            assert hasattr(module.input_quantizer, "amax")
+            assert hasattr(module.weight_quantizer, "amax")
+            assert getattr(module.input_quantizer, "amax") is not None
+            assert getattr(module.weight_quantizer, "amax") is not None
+            # Since LoRA was added before quantization, the adapters should carry quantizers too.
+            for adapter_name in module._lora_adapters:
+                lora_a = module._lora_adapters[adapter_name]["lora_a"]
+                lora_b = module._lora_adapters[adapter_name]["lora_b"]
+                assert hasattr(lora_a, "input_quantizer")
+                assert hasattr(lora_b, "weight_quantizer")
+                assert hasattr(lora_a.input_quantizer, "amax")
+                assert hasattr(lora_b.weight_quantizer, "amax")
+                assert getattr(lora_a.input_quantizer, "amax") is not None
+                assert getattr(lora_b.weight_quantizer, "amax") is not None
+
+    assert not torch.allclose(lora_output, quantized_output)
+
+    mtq.disable_quantizer(model, "*lora_a*")
+    disabled_lora_a_quantized_output = megatron_prefill(model, prompt_tokens)
+    # Should not match, since the lora_a quantizers are now disabled
+    assert not torch.allclose(disabled_lora_a_quantized_output, quantized_output)
+
+    mtq.disable_quantizer(model, "*lora_b*")
+    disabled_lora_ab_quantized_output = megatron_prefill(model, prompt_tokens)
+    assert not torch.allclose(disabled_lora_a_quantized_output, disabled_lora_ab_quantized_output)
+    assert not torch.allclose(quantized_output, disabled_lora_ab_quantized_output)
+
+
+@pytest.mark.parametrize("device_count", get_device_counts())
+@pytest.mark.parametrize(
+    "lora_config",
+    [
+        LARGE_LORA_CFG_RANDOM_INIT_TEST,  # Use random init so gradients flow to both lora_a and lora_b
+    ],
+)
+def test_lora_then_quantize(device_count, lora_config, tmp_path):
+    spawn_multiprocess_job(
+        size=device_count,
+        job=partial(_test_lora_then_quantize, lora_config, str(tmp_path)),
+        backend="nccl",
+    )
+
+
+def _test_mcore_quantize_then_lora_save_restore(lora_config, tmp_path, rank, size):
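+    """Quantize, then add LoRA, then save and restore a sharded checkpoint.
+
+    The restored model should reproduce the reference output, keep quantizers on
+    the base LoRAModules, and leave the adapter weights unquantized.
+    """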
+    hidden_size = 512
+    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
+    model_ref = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
+    model_test = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(
+        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
+    ).cuda()
+    original_output_test = megatron_prefill(model_test, prompt_tokens)
+
+    def forward_func(mod):
+        _ = megatron_prefill(model_ref, prompt_tokens)
+
+    mtq.quantize(model_ref, NVFP4_DEFAULT_CONFIG, forward_func)
+    mtpf.update_model(model_ref, lora_config)
+
+    quantize_lora_output_ref = megatron_prefill(model_ref, prompt_tokens)
+
+    save_distributed_checkpoint(tmp_path, model_ref)
+    save_sharded_modelopt_state([model_ref], tmp_path)
+
+    restore_sharded_modelopt_state([model_test], tmp_path)
+    model_test = load_distributed_checkpoint(tmp_path, model_test)
+
+    quantize_lora_output_test = megatron_prefill(model_test, prompt_tokens)
+
+    # If save and restore work correctly, the restored model should produce the same output.
+    assert torch.allclose(quantize_lora_output_test, quantize_lora_output_ref)
+
+    assert not torch.allclose(original_output_test, quantize_lora_output_test)
+
+    # Check the quantizer and LoRA layers after restore
+    for name, module in model_test.named_modules():
+        if isinstance(module, LoRAModule) and "output_layer" not in name:
+            assert hasattr(module, "input_quantizer")
+            assert hasattr(module, "weight_quantizer")
+            assert hasattr(module.input_quantizer, "amax")
+            assert hasattr(module.weight_quantizer, "amax")
+            assert getattr(module.input_quantizer, "amax") is not None
+            assert getattr(module.weight_quantizer, "amax") is not None
+            # The LoRA adapters themselves should not carry quantizers.
+            for adapter_name in module._lora_adapters:
+                lora_a = module._lora_adapters[adapter_name]["lora_a"]
+                lora_b = module._lora_adapters[adapter_name]["lora_b"]
+                assert not hasattr(lora_a, "input_quantizer")
+                assert not hasattr(lora_b, "weight_quantizer")
+
+
+@pytest.mark.parametrize("device_count", get_device_counts())
+@pytest.mark.parametrize(
+    "lora_config",
+    [
+        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
+    ],
+)
+def test_mcore_quantize_then_lora_save_restore(device_count, lora_config, tmp_path):
+    spawn_multiprocess_job(
+        size=device_count,
+        job=partial(_test_mcore_quantize_then_lora_save_restore, lora_config, str(tmp_path)),
+        backend="nccl",
+    )
+
+
+def _test_mcore_lora_then_quantize_save_restore(lora_config, tmp_path, rank, size):
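+    """Add LoRA, then quantize, then save and restore a sharded checkpoint.
+
+    The restored model should reproduce the reference output, and both the base
+    LoRAModules and the adapter weights should carry calibrated quantizers.
+    """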
+    hidden_size = 512
+    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
+    model_ref = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
+    model_test = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(
+        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
+    ).cuda()
+    original_output_test = megatron_prefill(model_test, prompt_tokens)
+
+    mtpf.update_model(model_ref, lora_config)
+
+    def forward_func(mod):
+        _ = megatron_prefill(model_ref, prompt_tokens)
+
+    mtq.quantize(model_ref, NVFP4_DEFAULT_CONFIG, forward_func)
+
+    lora_quantize_output_ref = megatron_prefill(model_ref, prompt_tokens)
+
+    save_distributed_checkpoint(tmp_path, model_ref)
+    save_sharded_modelopt_state([model_ref], tmp_path)
+
+    restore_sharded_modelopt_state([model_test], tmp_path)
+    model_test = load_distributed_checkpoint(tmp_path, model_test)
+
+    lora_quantize_output_test = megatron_prefill(model_test, prompt_tokens)
+
+    # If save and restore work correctly, the restored model should produce the same output.
+    assert torch.allclose(lora_quantize_output_test, lora_quantize_output_ref)
+
+    assert not torch.allclose(original_output_test, lora_quantize_output_test)
+
+    # Check the LoRA and quantizer layers after restore
+    for name, module in model_test.named_modules():
+        if isinstance(module, LoRAModule) and "output_layer" not in name:
+            assert hasattr(module, "input_quantizer")
+            assert hasattr(module, "weight_quantizer")
+            assert hasattr(module.input_quantizer, "amax")
+            assert hasattr(module.weight_quantizer, "amax")
+            assert getattr(module.input_quantizer, "amax") is not None
+            assert getattr(module.weight_quantizer, "amax") is not None
+            # Since LoRA was added before quantization, the adapters should carry quantizers too.
+            for adapter_name in module._lora_adapters:
+                lora_a = module._lora_adapters[adapter_name]["lora_a"]
+                lora_b = module._lora_adapters[adapter_name]["lora_b"]
+                assert hasattr(lora_a, "input_quantizer")
+                assert hasattr(lora_b, "weight_quantizer")
+                assert hasattr(lora_a.input_quantizer, "amax")
+                assert hasattr(lora_b.weight_quantizer, "amax")
+                assert getattr(lora_a.input_quantizer, "amax") is not None
+                assert getattr(lora_b.weight_quantizer, "amax") is not None
+
+
+@pytest.mark.parametrize("device_count", get_device_counts())
+@pytest.mark.parametrize(
+    "lora_config",
+    [
+        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
+    ],
+)
+def test_mcore_lora_then_quantize_save_restore(device_count, lora_config, tmp_path):
+    spawn_multiprocess_job(
+        size=device_count,
+        job=partial(_test_mcore_lora_then_quantize_save_restore, lora_config, str(tmp_path)),
+        backend="nccl",
+    )