     get_vulkan_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
+from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
 from ..model_factory import EagerModelFactory
 from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -57,6 +58,7 @@
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
+    QuantizedGroupEmbedding,
 )
 from .source_transformation.quantized_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
@@ -593,24 +595,53 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager: |
         dtype_override=dtype_override,
         args=args,
     )
     # Assumes the checkpoint has uniform dtype.
     checkpoint_dtype = next(edge_manager.model.parameters()).dtype
     print(f"checkpoint dtype: {checkpoint_dtype}")
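+    # Hypothetical guard for the uniform-dtype assumption above (not part of
+    # the original change): fail fast if the checkpoint mixes dtypes rather
+    # than silently quantizing with the dtype of the first parameter.
+    assert all(
+        p.dtype == checkpoint_dtype for p in edge_manager.model.parameters()
+    ), f"Expected a uniform checkpoint dtype of {checkpoint_dtype}"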
-    # We want to quantize with the model in the checkpoint dtype before casting to dtype_override.
+    # We want to quantize the weights of the model in the checkpoint dtype.
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
             args.model, DType.from_torch_dtype(checkpoint_dtype), args
         )
     )
 
-    # Override dtype of the model as specified by the user args.
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        torch_dtype = dtype_override.to_torch_dtype()
-        logging.info(f"model.to {torch_dtype}")
-        edge_manager.model = edge_manager.model.to(dtype=torch_dtype)
+    # We want to compute the actual ops in the precision of the dtype_override.
+    def _set_precision_to_fp32(module):
+        """
+        Recursively iterate through the module and set the precision of all
+        Int8DynActInt4WeightLinear submodules (and the dtype of all
+        QuantizedGroupEmbedding submodules) to torch.float32.
+        """
+        for name, child in module.named_children():
+            if isinstance(child, Int8DynActInt4WeightLinear):
+                # Change the precision attribute to torch.float32
+                child.precision = torch.float32
+                print(f"Changed precision of {name} to torch.float32")
+            elif isinstance(child, QuantizedGroupEmbedding):
+                child.dtype = torch.float32
+                print(f"Changed dtype of {name} to torch.float32")
+            else:
+                # Recursively apply to child modules
+                _set_precision_to_fp32(child)
+
+    _set_precision_to_fp32(edge_manager.model)
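+    # Hypothetical sanity check (not part of the original change): confirm the
+    # recursive walk above reached every quantized submodule. Relies only on
+    # the same `precision` / `dtype` attributes that _set_precision_to_fp32 sets.
+    for name, child in edge_manager.model.named_modules():
+        if isinstance(child, Int8DynActInt4WeightLinear):
+            assert child.precision == torch.float32, f"{name} was not set to fp32"
+        elif isinstance(child, QuantizedGroupEmbedding):
+            assert child.dtype == torch.float32, f"{name} was not set to fp32"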
 
     return edge_manager
 