
Commit df5be3e

Move to helper function

1 parent 5760f82 commit df5be3e

2 files changed (+28 -24 lines)

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 23 deletions
@@ -46,7 +46,6 @@
     get_vulkan_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
-from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
 from ..model_factory import EagerModelFactory
 from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -56,9 +55,9 @@
 
 from .source_transformation.attention import replace_attention_to_attention_sha
 from .source_transformation.quantize import (
+    _set_quantized_computation_dtype,
     get_quant_embedding_transform,
     get_quant_weight_transform,
-    QuantizedGroupEmbedding,
 )
 from .source_transformation.quantized_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
@@ -606,27 +605,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         )
     )
 
-    # We want to do compute the actual ops in the precision of the dtype_override,
-    # since the precision of the quantized linear will initially be the dtype of the
-    # checkpoint, not the dtype_override.
-    def _set_precision_to_fp32(module):
-        """
-        Recursively iterate through the module and set the precision attribute
-        of all Int8DynActInt4WeightLinear submodules to 'fp32'.
-        """
-        for name, child in module.named_children():
-            if isinstance(child, Int8DynActInt4WeightLinear):
-                # Change the precision attribute to 'fp32'
-                child.precision = torch.float32
-                print(f"Changed precision of {name} to torch.float32")
-            elif isinstance(child, QuantizedGroupEmbedding):
-                child.dtype = torch.float32
-                print(f"Changed precision of {name} to torch.float32")
-            else:
-                # Recursively apply to child modules
-                _set_precision_to_fp32(child)
-
-    _set_precision_to_fp32(edge_manager.model)
+    _set_quantized_computation_dtype(edge_manager.model, dtype_override.to_torch_dtype())
 
     return edge_manager
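The call above swaps the hard-coded torch.float32 fix-up for the export-time dtype_override. The problem the removed comment describes is an ordinary dtype mismatch: a module built from the checkpoint computes in the checkpoint dtype, while inputs arrive in the override dtype. Below is a minimal, self-contained sketch of that failure mode, using a plain nn.Linear as an illustrative stand-in for the quantized linear (no torchao involved):

import torch
import torch.nn as nn

# "Checkpoint" dtype is fp16; inputs arrive in the fp32 override dtype.
linear = nn.Linear(4, 4).half()
x = torch.randn(2, 4)  # fp32 by default

try:
    linear(x)
except RuntimeError as err:
    print(err)  # mat1 and mat2 must have the same dtype

# Aligning the module's computation dtype with the override fixes the call.
linear = linear.to(torch.float32)
print(linear(x).dtype)  # torch.float32

The quantized modules differ in that their computation dtype lives in a precision/dtype attribute rather than in the parameter dtype, which is presumably why the helper patches those attributes instead of converting parameters with .to().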

examples/models/llama/source_transformation/quantize.py

Lines changed: 26 additions & 1 deletion
@@ -15,10 +15,11 @@
 import torch.nn.functional as F
 
 from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer
-
 from executorch.extension.llm.export.builder import DType
 
 from sentencepiece import SentencePieceProcessor
+from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+
 
 try:
     from fairseq2.nn.embedding import (
@@ -827,4 +828,28 @@ def _load_torchao_aten_lib(libname):
     torch.ops.load_library(libs[0])
 
 
+# We want to compute the actual ops in the dtype of the dtype_override,
+# since the precision of the quantized linear will initially be the dtype of the
+# checkpoint, not the dtype_override.
+# TODO(#8652): this is a temporary solution until we can support the new torchao
+# quantize_ API, which reportedly supports different dtypes for quantization
+# and computation.
+def _set_quantized_computation_dtype(module: nn.Module, dtype: torch.dtype):
+    """
+    Recursively iterate through the module and set the dtype/precision attributes
+    of all Int8DynActInt4WeightLinear and QuantizedGroupEmbedding submodules to `dtype`.
+    """
+    for name, child in module.named_children():
+        if isinstance(child, Int8DynActInt4WeightLinear):
+            # Change the precision attribute to the target dtype
+            child.precision = dtype
+            print(f"Changed precision of {name} to {dtype}")
+        elif isinstance(child, QuantizedGroupEmbedding):
+            child.dtype = dtype
+            print(f"Changed precision of {name} to {dtype}")
+        else:
+            # Recursively apply to child modules
+            _set_quantized_computation_dtype(child, dtype)
+
+
 ############################ Source Transform End #######################
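For reference, here is a self-contained sketch of the traversal pattern the new helper uses: walk named_children(), patch matching leaves in place, and recurse into everything else. FakeQuantLinear and FakeQuantEmbedding below are hypothetical stand-ins for Int8DynActInt4WeightLinear and QuantizedGroupEmbedding:

import torch
import torch.nn as nn

class FakeQuantLinear(nn.Linear):
    """Stand-in for Int8DynActInt4WeightLinear; computes in `precision`."""
    precision: torch.dtype = torch.float16

class FakeQuantEmbedding(nn.Embedding):
    """Stand-in for QuantizedGroupEmbedding; computes in `dtype`."""
    dtype: torch.dtype = torch.float16

def set_quantized_computation_dtype(module: nn.Module, dtype: torch.dtype) -> None:
    # Same shape as _set_quantized_computation_dtype above: patch matching
    # leaves, recurse into all other children.
    for _name, child in module.named_children():
        if isinstance(child, FakeQuantLinear):
            child.precision = dtype
        elif isinstance(child, FakeQuantEmbedding):
            child.dtype = dtype
        else:
            set_quantized_computation_dtype(child, dtype)

model = nn.Sequential(
    FakeQuantEmbedding(100, 8),
    nn.Sequential(FakeQuantLinear(8, 8), nn.ReLU()),
)
set_quantized_computation_dtype(model, torch.float32)
assert model[0].dtype == torch.float32
assert model[1][0].precision == torch.float32

Note that, as in the original, a matched module is patched but not recursed into; that is fine as long as the matched types are leaves, as they are here.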
