
Commit 2c1a355

jackzhxng authored and facebook-github-bot committed
Fix xnnpack quantization discrepancy for non-fp32 (#8488)
Summary:

Perform quantization on the weights expressed in their original dtype (from the checkpoint) by passing the checkpoint dtype into the quantization source transformation, and set the computation dtype (the result dtype of the dequant, i.e. the dtype the ops are actually computed in) to the dtype override. We must do it this way because the checkpoint dtype and the computation dtype are coupled into a single `precision` parameter in the torchao API, and that is something we cannot change.

Note: no need to worry about https://github.com/pytorch/ao/blob/main/torchao/quantization/GPTQ.py#L1168; `precision` is passed in with the checkpoint dtype.

### Comparison of an arbitrary q_proj tensor from a sample Llama checkpoint

Before:

```
Mismatched elements: 3260378 / 4194304 (77.7%)
Greatest absolute difference: 0.08802086114883423 at index (1129, 604) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (0, 1350) (up to 1.3e-06 allowed)
Signal-to-noise: 32.8974 dB
```

After: no difference.

Test Plan:

### Manual testing

```
python -m examples.models.llama.export_llama \
  -v -c xl_consolidated/consolidated_renamed.pth \
  -p xl_consolidated/et_params.json -kv -d fp32 \
  -qmode 8da4w --group_size 32 -X \
  --use_sdpa_with_kv_cache \
  --output_name quantized_baseline.pte \
  --max_context_length 4096 -E 4,32
```

With the following inserted after the quantization:

```
edge_manager.model(
    torch.tensor([[2, 3, 4]], dtype=torch.long),
    {"input_pos": torch.tensor([0], dtype=torch.long)},
)
```

And the following modifications to GPTQ.py in torchao for testing: pytorch/ao#1756

### Automated testing

Existing CI tests.

### Regression testing

TBD

Reviewed By: kimishpatel

Differential Revision: D70184325

Pulled By: jackzhxng
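For orientation, the 8da4w path after this change boils down to two steps: derive the quantization parameters while the weights are still expressed in the checkpoint dtype, then retarget the dequantized output (and hence the compute) to the dtype override. The snippet below is a condensed sketch of that idea on a toy module, not code from the commit; it assumes torchao is installed and, per the note above, may need the GPTQ.py change referenced in pytorch/ao#1756 to run end to end.

```python
import torch
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

# Toy stand-in for a Llama layer, loaded in the checkpoint dtype (bf16 here).
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to(torch.bfloat16)

# 1. Quantize with precision = checkpoint dtype, so scales/zero-points are
#    computed against the weights as they were stored.
model = Int8DynActInt4WeightQuantizer(
    precision=torch.bfloat16, groupsize=32
).quantize(model)

# 2. Retarget the dequant/computation dtype to the dtype override (fp32),
#    mirroring the set_8da4w_computation_dtype() helper added in the diff below.
for child in model.modules():
    if isinstance(child, Int8DynActInt4WeightLinear):
        child.precision = torch.float32

out = model(torch.randn(1, 64))
print(out.dtype)  # expected: torch.float32
```

The same split shows up in every quantization mode touched below: `precision` passed to the quantizer carries the checkpoint dtype, while the computation dtype is applied afterwards wherever the backend allows the two to be decoupled.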
1 parent d16b867 commit 2c1a355

File tree: 6 files changed (+159, -28 lines)

examples/models/checkpoint.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -9,6 +9,8 @@
 from pathlib import Path
 from typing import Any, Dict, Optional
 
+import torch
+
 
 def get_default_model_resource_dir(model_file_path: str) -> Path:
     """
@@ -52,7 +54,7 @@ def get_default_model_resource_dir(model_file_path: str) -> Path:
     return resource_dir
 
 
-def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
+def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[torch.dtype]:
     """
     Get the dtype of the checkpoint, returning "None" if the checkpoint is empty.
     """
```

examples/models/llama/export_llama_lib.py

Lines changed: 60 additions & 7 deletions
```diff
@@ -16,6 +16,7 @@
 import re
 import shlex
 from enum import Enum
+from functools import partial
 from json import JSONDecodeError
 from pathlib import Path
 from typing import Callable, List, Optional, Union
@@ -594,9 +595,36 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     )
 
     # At this point, the model is loaded in the default fp32.
+
+    # Checkpoint dtype should be lower or equal precision to the dtype override.
+    checkpoint_dtype = edge_manager.model.checkpoint_dtype
+    if not (
+        checkpoint_dtype == dtype_override.to_torch_dtype()
+        or (
+            checkpoint_dtype == torch.float16
+            and dtype_override.to_torch_dtype() == torch.float32
+        )
+        or (
+            checkpoint_dtype == torch.bfloat16
+            and dtype_override.to_torch_dtype() == torch.float32
+        )
+    ):
+        logging.warning(
+            f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
+        )
+
     edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
-    edge_manager.set_output_dir(output_dir_path).source_transform(
-        _get_source_transforms(args.model, dtype_override, args)
+
+    # We want to quantize (in the source transforms) the weights of the model
+    # in the checkpoint dtype.
+    logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
+    edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
+        _get_source_transforms(
+            modelname=args.model,
+            dtype_override=dtype_override,
+            checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),
+            args=args,
+        )
     )
 
     return edge_manager
@@ -784,8 +812,6 @@ def _to_edge_and_lower_llama(  # noqa: C901
         shares=args.num_sharding,
     )
 
-    from functools import partial
-
     # pyre-ignore
     from executorch.backends.qualcomm.quantizer.custom_annotation import (
         get_custom_quant_ios_dtype,
@@ -1069,8 +1095,31 @@ def _load_llama_model(
 
 
 def _get_source_transforms(  # noqa
-    modelname: str, dtype_override: Optional[DType], args
+    modelname: str,
+    dtype_override: DType,
+    *,
+    checkpoint_dtype: Optional[DType] = None,
+    args,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
+    """
+    Return a list of functions that transform a graph.
+
+    Args:
+        modelname: The name of the model.
+        dtype_override: The dtype to use for the model.
+        checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
+            it means that you want to run quantize transformations on the weights represented
+            in their original dtype, while the overall dtype of the model maybe something
+            different. If not specified, defaults to dtype_override.
+        args: The arguments passed to the script.
+
+    Returns:
+        A list of transformation functions.
+    """
+
+    if not checkpoint_dtype:
+        checkpoint_dtype = dtype_override
+
     transforms = []
 
     if args.use_spin_quant:
@@ -1103,7 +1152,11 @@ def _get_source_transforms(  # noqa
         """
         modelname = f"{modelname}_q"
         transforms.append(
-            get_quant_weight_transform(args, dtype_override, verbose_export())
+            get_quant_weight_transform(
+                args=args,
+                computation_dtype=dtype_override,
+                checkpoint_dtype=checkpoint_dtype,
+            )
        )

    if args.embedding_quantize:
@@ -1117,7 +1170,7 @@ def _get_source_transforms(  # noqa
         this wil be a no-op.
         """
         modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args))
+        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
 
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
```
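The nested `if not (...)` guard in the `_prepare_for_llama_export` hunk above encodes a simple rule: warn unless the checkpoint dtype equals the dtype override, or the checkpoint is fp16/bf16 being computed in fp32. Restated as a standalone predicate for clarity (a sketch derived from that condition, not code from the commit):

```python
import torch

# The only cross-dtype combinations the guard accepts silently: a half-precision
# checkpoint promoted to fp32 computation.
_ALLOWED_PROMOTIONS = {
    (torch.float16, torch.float32),
    (torch.bfloat16, torch.float32),
}


def checkpoint_dtype_is_compatible(
    checkpoint_dtype: torch.dtype, dtype_override: torch.dtype
) -> bool:
    return (
        checkpoint_dtype == dtype_override
        or (checkpoint_dtype, dtype_override) in _ALLOWED_PROMOTIONS
    )


assert checkpoint_dtype_is_compatible(torch.bfloat16, torch.float32)
assert not checkpoint_dtype_is_compatible(torch.float32, torch.float16)  # triggers the warning
```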

examples/models/llama/source_transformation/quantize.py

Lines changed: 82 additions & 17 deletions
```diff
@@ -18,6 +18,7 @@
 
 from sentencepiece import SentencePieceProcessor
 
+
 try:
     from fairseq2.nn.embedding import (
         Embedding as fsEmbedding,
@@ -36,7 +37,8 @@
 def quantize(  # noqa C901
     model: torch.nn.Module,
     qmode: str,
-    activation_dtype: Optional[DType],
+    computation_dtype: Optional[DType] = None,
+    checkpoint_dtype: Optional[DType] = None,
     checkpoint_path: Optional[Path] = None,
     # following arguments only available when setting int4 or gptq quantization.
     group_size: Optional[int] = 128,
@@ -52,20 +54,33 @@ def quantize(  # noqa C901
 ) -> torch.nn.Module:
     """
     Quantizes a model by converting all weights to int8.
+
     Args:
-        model: A model to quantize.
-        qmode: quantization mode, e.g. int8, 8da4w, 8da4w-gptq
+        model: The model to quantize.
+        qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq.
+        computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization).
+            Also the dtype of the rest of the non-quantized compoents of the model.
+        checkpoint_dtype: The dtype of the checkpoint, this arg exists since it is more accurate to
+            quantize the weight in its original dtype.
+
     Returns:
         A quantized model.
     """
-    if activation_dtype is not None:
-        torch_dtype = activation_dtype.to_torch_dtype()
+    if computation_dtype:
+        computation_torch_dtype = computation_dtype.to_torch_dtype()
     else:
-        torch_dtype = torch.float16
+        computation_torch_dtype = torch.float32
+
+    if not checkpoint_dtype:
+        checkpoint_torch_dtype = computation_torch_dtype
+    else:
+        checkpoint_torch_dtype = checkpoint_dtype.to_torch_dtype()
 
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
-        return WeightOnlyInt8QuantHandler(model).quantized_model()
+        return WeightOnlyInt8QuantHandler(
+            model, precision=checkpoint_torch_dtype
+        ).quantized_model()
     elif qmode.startswith("torchao:fpa"):
         pattern = r"torchao:fpa(\d+)w"
         matches = re.findall(pattern, qmode)
@@ -75,10 +90,12 @@ def quantize(  # noqa C901
         from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
 
         with torch.no_grad():
+            # This quantize() is currently doing a model.to(self.precision) so cannot
+            # decouple computation and checkpoint dtypes.
             model = (
                 UIntxWeightOnlyLinearQuantizer(
                     device="mps",
-                    precision=torch.float32,
+                    precision=computation_torch_dtype,
                     groupsize=group_size,
                     bitwidth=bitwidth,
                 )
@@ -101,6 +118,8 @@ def quantize(  # noqa C901
         from torchao.utils import unwrap_tensor_subclass
 
         with torch.no_grad():
+            # Computation dtype is fixed to fp32 in the implementation of quantize_, so
+            # no way to decouple checkpoint and computation dtype.
             quantize_(
                 model,
                 Int8DynamicActivationIntxWeightConfig(
@@ -121,9 +140,12 @@ def quantize(  # noqa C901
             raise Exception("For 8da4w quantization, group size must be specified.")
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
+        # 1. Quantize in checkpoint dtype.
         model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=group_size
+            precision=checkpoint_torch_dtype, groupsize=group_size
        ).quantize(model)
+        # 2. Set the computation dtype (what weights/acts dequantize to).
+        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
 
         if verbose:
             print("quantized model:", model)
@@ -177,7 +199,7 @@ def quantize(  # noqa C901
             blocksize,
             percdamp,
             group_size,
-        )
+        )  # TODO: separate computation and checkpoint dtype for GPTQ.
         model = gptq_quantizer.quantize(model, inputs)
         return model
     elif qmode == "vulkan_4w":
@@ -190,9 +212,12 @@ def quantize(  # noqa C901
         # at the moment
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
+        # 1. Quantize in checkpoint dtype.
         model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=q_group_size
+            precision=checkpoint_torch_dtype, groupsize=q_group_size
        ).quantize(model)
+        # 2. Set the computation dtype (what weights/acts dequantize to).
+        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
 
         return model
     else:
@@ -348,6 +373,7 @@ def __init__(
         node_type: str = "*",
         bitwidth: Optional[int] = None,
         group_size: Optional[int] = None,
+        precision: torch.dtype = torch.float32,
     ):
         self.mod = mod
         self.group_size = group_size
@@ -356,6 +382,7 @@ def __init__(
             self.bitwidth = 8
         else:
             self.bitwidth = bitwidth
+        self.precision = precision
 
     @torch.no_grad()
     def create_quantized_state_dict(self) -> Dict:
@@ -391,7 +418,7 @@ def create_quantized_state_dict(self) -> Dict:
 
                 # print(f"expanded weight shape {input_weight.shape}")
                 weight, scales, _ = dynamically_quantize_per_channel(
-                    input_weight,
+                    input_weight.to(dtype=self.precision),
                     range_min,
                    range_max,
                    torch.int8,
@@ -576,6 +603,7 @@ def __init__(
         bitwidth: int = 8,
         group_size: Optional[int] = None,
         packed=False,
+        precision: Optional[torch.dtype] = None,
     ):
         if isinstance(packed, str):
             packed = packed == "True"
@@ -584,6 +612,8 @@ def __init__(
         self.group_size = group_size
         self.bitwidth = bitwidth
         self.packed = packed
+        # Dtype of the weights right before quantization.
+        self.precision = precision
         if (bitwidth not in [2, 4]) and packed:
             raise RuntimeError("pack only works with bitsize 2, 4")
 
@@ -614,7 +644,11 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                 f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
             )
             weight, scales, _ = dynamically_quantize_per_channel(
-                mod.weight.float(),
+                (
+                    mod.weight.to(dtype=self.precision)
+                    if self.precision
+                    else mod.weight
+                ),
                 range_min,
                 range_max,
                 torch.int8,
@@ -750,7 +784,7 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 ############################ Source Transform Start #######################
 
 
-def get_quant_embedding_transform(args):
+def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
     if args.embedding_quantize.startswith("torchao:"):
         bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
         group_size = int(group_size)
@@ -775,16 +809,22 @@ def _torchao_embedding_quantizer(model):
     else:
         group_size = int(group_size)
     bitwidth = int(bitwidth)
+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
     return lambda model: EmbeddingQuantHandler(
         model,
         bitwidth=bitwidth,
         group_size=group_size,
         packed=(bitwidth in [2, 4]),
+        precision=torch_dtype,
     ).quantized_model()
 
 
-def get_quant_weight_transform(args, dtype_override, verbose):
-    # If these optional args are None, don't provide them to quantize()
+def get_quant_weight_transform(
+    args,
+    computation_dtype: Optional[DType] = None,
+    checkpoint_dtype: Optional[DType] = None,
+):
+    # If these optional args are None, don't provide them to quantize().
     quant_args_str = [
         "group_size",
         "calibration_tasks",
@@ -802,7 +842,8 @@ def get_quant_weight_transform(args, dtype_override, verbose):
         quantize,
         **quant_args,
         qmode=args.quantization_mode,
-        activation_dtype=dtype_override,
+        computation_dtype=computation_dtype,
+        checkpoint_dtype=checkpoint_dtype,
         checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
         tokenizer_path=(
             Path(path) if (path := args.tokenizer_path) is not None else None
@@ -829,4 +870,28 @@ def _load_torchao_aten_lib(libname):
     torch.ops.load_library(libs[0])
 
 
+# We want to do compute the actual ops in the computation dtype, since the precision of the
+# quantized linear will initially be the dtype of the checkpoint.
+def set_8da4w_computation_dtype(
+    module: nn.Module, computation_dtype: torch.dtype
+) -> nn.Module:
+
+    from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+
+    def _set_8da4w_computation_dtype(module: nn.Module, dtype: torch.dtype) -> None:
+        """
+        Recursively iterate through the module and set the precision attributes
+        of all Int8DynActInt4WeightLinears.
+        """
+        for _name, child in module.named_children():
+            if isinstance(child, Int8DynActInt4WeightLinear):
+                child.precision = dtype
+            else:
+                # Recursively apply to child modules
+                _set_8da4w_computation_dtype(child, dtype)
+
+    _set_8da4w_computation_dtype(module, computation_dtype)
+    return module
+
+
 ############################ Source Transform End #######################
```
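To reproduce the kind of before/after comparison quoted in the commit summary above (mismatched-element count, greatest absolute/relative difference, signal-to-noise in dB) for a tensor such as a q_proj output, a small comparison helper along these lines can be used. The tolerances mirror the ones in the quoted report; the helper itself is not part of the commit.

```python
import torch


def compare_tensors(reference: torch.Tensor, candidate: torch.Tensor) -> None:
    """Print mismatch statistics similar to those quoted in the summary."""
    ref = reference.float()
    cand = candidate.float()
    close = torch.isclose(ref, cand, rtol=1.3e-6, atol=1e-5)
    mismatched = (~close).sum().item()
    abs_diff = (ref - cand).abs()
    rel_diff = abs_diff / ref.abs().clamp_min(1e-12)
    snr_db = 10 * torch.log10(
        ref.pow(2).mean() / (ref - cand).pow(2).mean().clamp_min(1e-20)
    )
    print(f"Mismatched elements: {mismatched} / {ref.numel()}")
    print(f"Greatest absolute difference: {abs_diff.max().item()}")
    print(f"Greatest relative difference: {rel_diff.max().item()}")
    print(f"Signal-to-noise: {snr_db.item():.4f} dB")
```

Running the few tokens from the test plan through the model before and after a change and comparing an intermediate tensor with a helper like this is enough to surface the dtype discrepancy this commit fixes.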

examples/models/llava/export_llava.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -100,7 +100,7 @@ def forward(self, input_pos, embeddings):
     args = parser.parse_args(
         ["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
     )
-    quant_transform = get_quant_weight_transform(args, dtype_override, False)
+    quant_transform = get_quant_weight_transform(args, dtype_override)
     _, quantizers, _ = get_quantizer_and_quant_params(args)
     source_transforms = []
     if llava.use_sdpa_with_kv_cache_op:
```

exir/tests/test_memory_planning.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -708,10 +708,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         et_program = et.executorch_program
         inputs = et_program.execution_plan[0].inputs
         self.assertNotEqual(
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[0]]
             .val.allocation_info.memory_offset_low,
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[1]]
             .val.allocation_info.memory_offset_low,
         )
```
