|
@@ -46,7 +46,6 @@
     get_vulkan_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
-from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
 from ..model_factory import EagerModelFactory
 from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -56,9 +55,9 @@
 
 from .source_transformation.attention import replace_attention_to_attention_sha
 from .source_transformation.quantize import (
+    _set_quantized_computation_dtype,
     get_quant_embedding_transform,
     get_quant_weight_transform,
-    QuantizedGroupEmbedding,
 )
 from .source_transformation.quantized_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
@@ -606,27 +605,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         )
     )
 
-    # We want to do compute the actual ops in the precision of the dtype_override,
-    # since the precision of the quantized linear will initially be the dtype of the
-    # checkpoint, not the dtype_override.
-    def _set_precision_to_fp32(module):
-        """
-        Recursively iterate through the module and set the precision attribute
-        of all Int8DynActInt4WeightLinear submodules to 'fp32'.
-        """
-        for name, child in module.named_children():
-            if isinstance(child, Int8DynActInt4WeightLinear):
-                # Change the precision attribute to 'fp32'
-                child.precision = torch.float32
-                print(f"Changed precision of {name} to torch.float32")
-            elif isinstance(child, QuantizedGroupEmbedding):
-                child.dtype = torch.float32
-                print(f"Changed precision of {name} to torch.float32")
-            else:
-                # Recursively apply to child modules
-                _set_precision_to_fp32(child)
-
-    _set_precision_to_fp32(edge_manager.model)
+    _set_quantized_computation_dtype(edge_manager.model, dtype_override.to_torch_dtype())
 
     return edge_manager
 
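The new one-liner relies on _set_quantized_computation_dtype, whose body is not shown in this diff. A minimal sketch of what it plausibly looks like, assuming it is the removed _set_precision_to_fp32 traversal moved into source_transformation/quantize.py and generalized to take the target dtype as a parameter:

# Hypothetical sketch; the actual helper lives in
# source_transformation/quantize.py and may differ in detail.
import torch
from torch import nn


def _set_quantized_computation_dtype(module: nn.Module, dtype: torch.dtype) -> None:
    """
    Recursively set the computation dtype of quantized submodules so the
    actual ops run in the dtype_override precision rather than the dtype
    of the checkpoint they were loaded from.
    """
    # Same import the diff removes from export_llama_lib; imported lazily
    # here to keep the sketch self-contained.
    from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

    for name, child in module.named_children():
        if isinstance(child, Int8DynActInt4WeightLinear):
            # The quantized linear computes in `precision` after dequant.
            child.precision = dtype
        elif type(child).__name__ == "QuantizedGroupEmbedding":
            # QuantizedGroupEmbedding (defined alongside this helper in
            # quantize.py, hence the name check here) exposes its output
            # dtype via `dtype`.
            child.dtype = dtype
        else:
            # Recurse into container modules.
            _set_quantized_computation_dtype(child, dtype)

Threading the dtype through as an argument, rather than hard-coding torch.float32, lets the same traversal honor whatever dtype_override the export was invoked with, e.g. fp16 or bf16.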
|
|