Commit 3d58d14

Merge branch 'pytorch:main' into exynos-quantize-support
2 parents: 44a5e9e + d39992f

File tree: 3 files changed (+40, -17 lines)

  examples/models/llama/export_llama_lib.py
  examples/models/llama/source_transformation/quantize.py
  third-party/ao

examples/models/llama/export_llama_lib.py

15 additions, 7 deletions

@@ -1238,12 +1238,15 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
-        EagerModelFactory.create_model(
-            module_name,
-            model_class_name,
-            llm_config=llm_config,
-        )
+    (
+        model,
+        example_inputs,
+        example_kwarg_inputs,
+        dynamic_shapes,
+    ) = EagerModelFactory.create_model(
+        module_name,
+        model_class_name,
+        llm_config=llm_config,
     )
     # Convert dtype override string to actual type.
     dtype_override = DType[llm_config.model.dtype_override.value]

@@ -1322,6 +1325,7 @@ def _get_source_transforms( # noqa
     local_global_attention: Optional[List[int]] = None,
     use_torchao_kernels_linear: bool = False,
     use_torchao_kernels_tied_embedding: bool = False,
+    quantize_with_hqq: bool = True,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.

@@ -1391,7 +1395,10 @@ def _get_source_transforms( # noqa
         """
         transforms.append(
             get_quant_embedding_transform(
-                embedding_quantize, use_shared_embedding, checkpoint_dtype
+                embedding_quantize,
+                use_shared_embedding,
+                checkpoint_dtype,
+                quantize_with_hqq,
             )
         )

@@ -1422,6 +1429,7 @@ def _get_source_transforms( # noqa
                 calibration_tasks=calibration_tasks,
                 calibration_limit=calibration_limit,
                 calibration_seq_length=calibration_seq_length,
+                quantize_with_hqq=quantize_with_hqq,
             )
         )
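Taken together, these hunks thread one new knob through the export path: _get_source_transforms gains quantize_with_hqq: bool = True and forwards it into both the embedding and the weight quantization transform factories (the first hunk is only a formatting change to the tuple unpacking). Below is a minimal, self-contained sketch of that threading pattern; the factory bodies are illustrative stand-ins, and only the flag's name, its default, and the list-of-callables return shape come from this diff.

from typing import Callable, List

import torch.nn as nn


def get_quant_embedding_transform(
    quantize_with_hqq: bool = True,
) -> Callable[[nn.Module], nn.Module]:
    # Stand-in for the real factory in quantize.py: capture the flag once,
    # then return a Module -> Module transform that would apply it.
    algorithm = "hqq_scale_only" if quantize_with_hqq else "affine"

    def transform(model: nn.Module) -> nn.Module:
        # The real transform calls into torchao using `algorithm`;
        # here we just tag the module for demonstration.
        model.qparams_algorithm = algorithm
        return model

    return transform


def _get_source_transforms(
    quantize_with_hqq: bool = True,
) -> List[Callable[[nn.Module], nn.Module]]:
    # As in the diff: the flag is forwarded verbatim to each factory.
    return [get_quant_embedding_transform(quantize_with_hqq)]


# The export pipeline applies the returned transforms in order.
model: nn.Module = nn.Embedding(16, 8)
for t in _get_source_transforms(quantize_with_hqq=False):
    model = t(model)
print(model.qparams_algorithm)  # "affine"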
examples/models/llama/source_transformation/quantize.py

24 additions, 9 deletions

@@ -49,6 +49,7 @@ def quantize( # noqa C901
     blocksize: int = 128,
     tokenizer_path: Optional[Path] = None,
     verbose: bool = False,
+    quantize_with_hqq: bool = True,
 ) -> torch.nn.Module:
     """
     Quantizes a model by converting all weights to int8.

@@ -119,7 +120,6 @@ def quantize( # noqa C901
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import (
             Int8DynamicActivationIntxWeightConfig,
-            MappingType,
             quantize_,
         )
         from torchao.utils import unwrap_tensor_subclass

@@ -134,9 +134,12 @@ def quantize( # noqa C901
                 weight_granularity=(
                     PerAxis(0) if group_size == 0 else PerGroup(group_size)
                 ),
-                weight_mapping_type=MappingType.SYMMETRIC,
                 # pyre-ignore[6]
                 intx_packing_format="opaque_torchao_auto",
+                # pyre-ignore[6]
+                intx_choose_qparams_algorithm=(
+                    "hqq_scale_only" if quantize_with_hqq else "affine"
+                ),
             ),
         )
         model = unwrap_tensor_subclass(model)

@@ -170,6 +173,10 @@ def filter_fn(m, fqn):
                 # pyre-ignore[16]
                 weight_dtype=torch.int4,
                 weight_granularity=PerGroup(group_size),
+                # pyre-ignore[6]
+                intx_choose_qparams_algorithm=(
+                    "hqq_scale_only" if quantize_with_hqq else "affine"
+                ),
             ),
             filter_fn=filter_fn,
         )

@@ -191,6 +198,10 @@ def filter_fn(m, fqn):
             # pyre-ignore[16]
             weight_dtype=torch.int4,
             granularity=PerGroup(q_group_size),
+            # pyre-ignore[6]
+            intx_choose_qparams_algorithm=(
+                "hqq_scale_only" if quantize_with_hqq else "affine"
+            ),
         )
         quantize_(model, q_config)
         model = unwrap_tensor_subclass(model)

@@ -580,6 +591,7 @@ def __init__(
         group_size: Optional[int] = None,
         packed=False,
         precision: Optional[torch.dtype] = None,
+        quantize_with_hqq: bool = True,
     ):
         if isinstance(packed, str):
             packed = packed == "True"

@@ -592,15 +604,12 @@
         self.precision = precision
         if (bitwidth not in [2, 4]) and packed:
             raise RuntimeError("pack only works with bitsize 2, 4")
+        self.quantize_with_hqq = quantize_with_hqq
 
     @torch.no_grad()
     def create_quantized_state_dict(self, packed=False) -> Dict:
         from torchao.quantization.granularity import PerAxis, PerGroup
-        from torchao.quantization.quant_api import (
-            IntxWeightOnlyConfig,
-            MappingType,
-            quantize_,
-        )
+        from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
 
         cur_state_dict = self.mod.state_dict()

@@ -627,7 +636,10 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                     if (self.group_size is None or self.group_size == 0)
                     else PerGroup(self.group_size)
                 ),
-                mapping_type=MappingType.SYMMETRIC,
+                # pyre-ignore[6]
+                intx_choose_qparams_algorithm=(
+                    "hqq_scale_only" if self.quantize_with_hqq else "affine"
+                ),
             )
             quantize_(tmp_model, config, lambda m, fqn: isinstance(m, nn.Embedding))
             weight = tmp_model.weight.qdata  # pyre-ignore[16]

@@ -765,6 +777,7 @@ def get_quant_embedding_transform(
     embedding_quantize: str,
     use_shared_embedding: bool = False,
     dtype_override: Optional[DType] = None,
+    quantize_with_hqq: bool = True,
 ):
     if embedding_quantize.startswith("torchao:"):
         from torchao.prototype.quantization.embedding.api import (

@@ -825,6 +838,7 @@ def _torchao_embedding_quantizer(model):
             group_size=group_size,
             packed=(bitwidth in [2, 4]),
             precision=torch_dtype,
+            quantize_with_hqq=quantize_with_hqq,
         ).quantized_model()

@@ -838,6 +852,7 @@ def get_quant_weight_transform(
     calibration_tasks: Optional[list] = None,
     calibration_limit: Optional[int] = None,
     calibration_seq_length: Optional[int] = None,
+    quantize_with_hqq: bool = True,
 ):
     return partial(
         quantize,

@@ -850,6 +865,7 @@
         calibration_limit=calibration_limit,
         calibration_seq_length=calibration_seq_length,
         tokenizer_path=(Path(path) if (path := tokenizer_path) is not None else None),
+        quantize_with_hqq=quantize_with_hqq,
     )

@@ -877,7 +893,6 @@ def _load_torchao_aten_lib(libname):
 def set_8da4w_computation_dtype(
     module: nn.Module, computation_dtype: torch.dtype
 ) -> nn.Module:
-
     from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightLinear
 
     def _set_8da4w_computation_dtype(module: nn.Module, dtype: torch.dtype) -> None:
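The recurring edit in this file swaps the hard-coded MappingType.SYMMETRIC arguments for an intx_choose_qparams_algorithm string selected by the new flag: "hqq_scale_only" uses HQQ's error-minimizing scale search, while "affine" keeps the plain affine qparams. A minimal sketch of the embedding path follows, assuming a torchao build matching the updated third-party/ao submodule (where IntxWeightOnlyConfig accepts this argument); the function name quantize_embeddings and its defaults are illustrative, not from the diff.

import torch
import torch.nn as nn
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_


def quantize_embeddings(
    model: nn.Module,
    group_size: int = 32,
    quantize_with_hqq: bool = True,
) -> nn.Module:
    config = IntxWeightOnlyConfig(
        weight_dtype=torch.int4,
        granularity=PerGroup(group_size),
        # Same selection as in the diff: HQQ's scale search when the flag
        # is set, plain affine qparams otherwise.
        intx_choose_qparams_algorithm=(
            "hqq_scale_only" if quantize_with_hqq else "affine"
        ),
    )
    # Quantize only the embedding tables, as the diff's filter lambda does.
    quantize_(model, config, lambda m, fqn: isinstance(m, nn.Embedding))
    return model


model = quantize_embeddings(nn.Embedding(1000, 64), group_size=32)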

third-party/ao

Submodule ao updated 67 files
