Commit 38851a1

Fix xnnpack quantization discrepancy for non-fp32
Differential Revision: D70184325
Pull Request resolved: #8488
1 parent 78f0e67 commit 38851a1
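
At a high level, the fix decouples the dtype that quantization parameters are computed in (the checkpoint dtype, e.g. bf16 or fp16) from the dtype the exported model computes in (the dtype override, typically fp32), so the weights of a non-fp32 checkpoint are now quantized in their original precision rather than after being cast to the override dtype. A minimal sketch of the new quantize() entry point as changed below; the import paths are inferred from the file paths in this diff and the DType members are assumptions:

# Sketch only: mirrors the quantize() signature introduced in this commit.
# Import paths are inferred from the diff and may need adjusting to your checkout.
from executorch.examples.models.llama.source_transformation.quantize import quantize
from executorch.extension.llm.export.builder import DType  # assumed location of DType

model = ...  # an eager Llama-style nn.Module loaded from a bf16 checkpoint

quantized = quantize(
    model,
    qmode="8da4w",
    computation_dtype=DType.fp32,   # what the exported model dequantizes to / runs in
    checkpoint_dtype=DType.bf16,    # quant params computed from the original bf16 weights
    group_size=128,
)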

File tree

6 files changed (+159 lines, -28 lines)

examples/models/checkpoint.py

Lines changed: 3 additions & 1 deletion
@@ -9,6 +9,8 @@
 from pathlib import Path
 from typing import Any, Dict, Optional
 
+import torch
+
 
 def get_default_model_resource_dir(model_file_path: str) -> Path:
     """
@@ -52,7 +54,7 @@ def get_default_model_resource_dir(model_file_path: str) -> Path:
     return resource_dir
 
 
-def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
+def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[torch.dtype]:
     """
     Get the dtype of the checkpoint, returning "None" if the checkpoint is empty.
     """

examples/models/llama/export_llama_lib.py

Lines changed: 60 additions & 7 deletions
@@ -16,6 +16,7 @@
 import re
 import shlex
 from enum import Enum
+from functools import partial
 from json import JSONDecodeError
 from pathlib import Path
 from typing import Callable, List, Optional, Union
@@ -594,9 +595,36 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     )
 
     # At this point, the model is loaded in the default fp32.
+
+    # Checkpoint dtype should be lower or equal precision to the dtype override.
+    checkpoint_dtype = edge_manager.model.checkpoint_dtype
+    if not (
+        checkpoint_dtype == dtype_override.to_torch_dtype()
+        or (
+            checkpoint_dtype == torch.float16
+            and dtype_override.to_torch_dtype() == torch.float32
+        )
+        or (
+            checkpoint_dtype == torch.bfloat16
+            and dtype_override.to_torch_dtype() == torch.float32
+        )
+    ):
+        logging.warning(
+            f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
+        )
+
     edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
-    edge_manager.set_output_dir(output_dir_path).source_transform(
-        _get_source_transforms(args.model, dtype_override, args)
+
+    # We want to quantize (in the source transforms) the weights of the model
+    # in the checkpoint dtype.
+    logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
+    edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
+        _get_source_transforms(
+            modelname=args.model,
+            dtype_override=dtype_override,
+            checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),
+            args=args,
+        )
     )
 
     return edge_manager
@@ -784,8 +812,6 @@ def _to_edge_and_lower_llama(  # noqa: C901
                 shares=args.num_sharding,
             )
 
-        from functools import partial
-
         # pyre-ignore
         from executorch.backends.qualcomm.quantizer.custom_annotation import (
             get_custom_quant_ios_dtype,
@@ -1069,8 +1095,31 @@ def _load_llama_model(
 
 
 def _get_source_transforms(  # noqa
-    modelname: str, dtype_override: Optional[DType], args
+    modelname: str,
+    dtype_override: DType,
+    *,
+    checkpoint_dtype: Optional[DType] = None,
+    args,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
+    """
+    Return a list of functions that transform a graph.
+
+    Args:
+        modelname: The name of the model.
+        dtype_override: The dtype to use for the model.
+        checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
+            it means that you want to run quantize transformations on the weights represented
+            in their original dtype, while the overall dtype of the model maybe something
+            different. If not specified, defaults to dtype_override.
+        args: The arguments passed to the script.
+
+    Returns:
+        A list of transformation functions.
+    """
+
+    if not checkpoint_dtype:
+        checkpoint_dtype = dtype_override
+
     transforms = []
 
     if args.use_spin_quant:
@@ -1103,7 +1152,11 @@ def _get_source_transforms(  # noqa
         """
         modelname = f"{modelname}_q"
         transforms.append(
-            get_quant_weight_transform(args, dtype_override, verbose_export())
+            get_quant_weight_transform(
+                args=args,
+                computation_dtype=dtype_override,
+                checkpoint_dtype=checkpoint_dtype,
+            )
         )
 
     if args.embedding_quantize:
@@ -1117,7 +1170,7 @@ def _get_source_transforms(  # noqa
         this wil be a no-op.
         """
         modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args))
+        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
 
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
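
The dtype check added above only accepts fp16→fp32 and bf16→fp32 as silent up-casts; any other mismatch (e.g. an fp32 checkpoint with a bf16 override) logs a warning. The same rule, condensed into a standalone predicate for illustration (the helper name is hypothetical):

import torch

def _checkpoint_dtype_is_compatible(checkpoint_dtype: torch.dtype, override: torch.dtype) -> bool:
    # Mirrors the condition added to _prepare_for_llama_export above.
    return (
        checkpoint_dtype == override
        or (checkpoint_dtype == torch.float16 and override == torch.float32)
        or (checkpoint_dtype == torch.bfloat16 and override == torch.float32)
    )

assert _checkpoint_dtype_is_compatible(torch.bfloat16, torch.float32)
assert not _checkpoint_dtype_is_compatible(torch.float32, torch.bfloat16)  # would trigger the warning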

examples/models/llama/source_transformation/quantize.py

Lines changed: 82 additions & 17 deletions
@@ -18,6 +18,7 @@
 
 from sentencepiece import SentencePieceProcessor
 
+
 try:
     from fairseq2.nn.embedding import (
         Embedding as fsEmbedding,
@@ -36,7 +37,8 @@
 def quantize(  # noqa C901
     model: torch.nn.Module,
     qmode: str,
-    activation_dtype: Optional[DType],
+    computation_dtype: Optional[DType] = None,
+    checkpoint_dtype: Optional[DType] = None,
     checkpoint_path: Optional[Path] = None,
     # following arguments only available when setting int4 or gptq quantization.
     group_size: Optional[int] = 128,
@@ -52,20 +54,33 @@ def quantize(  # noqa C901
 ) -> torch.nn.Module:
     """
     Quantizes a model by converting all weights to int8.
+
     Args:
-        model: A model to quantize.
-        qmode: quantization mode, e.g. int8, 8da4w, 8da4w-gptq
+        model: The model to quantize.
+        qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq.
+        computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization).
+            Also the dtype of the rest of the non-quantized compoents of the model.
+        checkpoint_dtype: The dtype of the checkpoint, this arg exists since it is more accurate to
+            quantize the weight in its original dtype.
+
     Returns:
         A quantized model.
     """
-    if activation_dtype is not None:
-        torch_dtype = activation_dtype.to_torch_dtype()
+    if computation_dtype:
+        computation_torch_dtype = computation_dtype.to_torch_dtype()
     else:
-        torch_dtype = torch.float16
+        computation_torch_dtype = torch.float32
+
+    if not checkpoint_dtype:
+        checkpoint_torch_dtype = computation_torch_dtype
+    else:
+        checkpoint_torch_dtype = checkpoint_dtype.to_torch_dtype()
 
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
-        return WeightOnlyInt8QuantHandler(model).quantized_model()
+        return WeightOnlyInt8QuantHandler(
+            model, precision=checkpoint_torch_dtype
+        ).quantized_model()
     elif qmode.startswith("torchao:fpa"):
         pattern = r"torchao:fpa(\d+)w"
         matches = re.findall(pattern, qmode)
@@ -75,10 +90,12 @@ def quantize(  # noqa C901
         from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
 
         with torch.no_grad():
+            # This quantize() is currently doing a model.to(self.precision) so cannot
+            # decouple computation and checkpoint dtypes.
            model = (
                 UIntxWeightOnlyLinearQuantizer(
                     device="mps",
-                    precision=torch.float32,
+                    precision=computation_torch_dtype,
                     groupsize=group_size,
                     bitwidth=bitwidth,
                 )
@@ -101,6 +118,8 @@ def quantize(  # noqa C901
         from torchao.utils import unwrap_tensor_subclass
 
         with torch.no_grad():
+            # Computation dtype is fixed to fp32 in the implementation of quantize_, so
+            # no way to decouple checkpoint and computation dtype.
             quantize_(
                 model,
                 Int8DynamicActivationIntxWeightConfig(
@@ -121,9 +140,12 @@ def quantize(  # noqa C901
             raise Exception("For 8da4w quantization, group size must be specified.")
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
+        # 1. Quantize in checkpoint dtype.
         model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=group_size
+            precision=checkpoint_torch_dtype, groupsize=group_size
         ).quantize(model)
+        # 2. Set the computation dtype (what weights/acts dequantize to).
+        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
 
         if verbose:
             print("quantized model:", model)
@@ -177,7 +199,7 @@ def quantize(  # noqa C901
             blocksize,
             percdamp,
             group_size,
-        )
+        )  # TODO: separate computation and checkpoint dtype for GPTQ.
         model = gptq_quantizer.quantize(model, inputs)
         return model
     elif qmode == "vulkan_4w":
@@ -190,9 +212,12 @@ def quantize(  # noqa C901
         # at the moment
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
+        # 1. Quantize in checkpoint dtype.
         model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=q_group_size
+            precision=checkpoint_torch_dtype, groupsize=q_group_size
         ).quantize(model)
+        # 2. Set the computation dtype (what weights/acts dequantize to).
+        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
 
         return model
     else:
@@ -348,6 +373,7 @@ def __init__(
         node_type: str = "*",
         bitwidth: Optional[int] = None,
         group_size: Optional[int] = None,
+        precision: torch.dtype = torch.float32,
     ):
         self.mod = mod
         self.group_size = group_size
@@ -356,6 +382,7 @@ def __init__(
             self.bitwidth = 8
         else:
             self.bitwidth = bitwidth
+        self.precision = precision
 
     @torch.no_grad()
     def create_quantized_state_dict(self) -> Dict:
@@ -391,7 +418,7 @@ def create_quantized_state_dict(self) -> Dict:
 
                 # print(f"expanded weight shape {input_weight.shape}")
                 weight, scales, _ = dynamically_quantize_per_channel(
-                    input_weight,
+                    input_weight.to(dtype=self.precision),
                     range_min,
                     range_max,
                     torch.int8,
@@ -576,6 +603,7 @@ def __init__(
         bitwidth: int = 8,
         group_size: Optional[int] = None,
         packed=False,
+        precision: Optional[torch.dtype] = None,
     ):
         if isinstance(packed, str):
             packed = packed == "True"
@@ -584,6 +612,8 @@ def __init__(
         self.group_size = group_size
         self.bitwidth = bitwidth
        self.packed = packed
+        # Dtype of the weights right before quantization.
+        self.precision = precision
         if (bitwidth not in [2, 4]) and packed:
             raise RuntimeError("pack only works with bitsize 2, 4")
 
@@ -614,7 +644,11 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                     f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
                 weight, scales, _ = dynamically_quantize_per_channel(
-                    mod.weight.float(),
+                    (
+                        mod.weight.to(dtype=self.precision)
+                        if self.precision
+                        else mod.weight
+                    ),
                     range_min,
                     range_max,
                     torch.int8,
@@ -750,7 +784,7 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 ############################ Source Transform Start #######################
 
 
-def get_quant_embedding_transform(args):
+def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
     if args.embedding_quantize.startswith("torchao:"):
         bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
         group_size = int(group_size)
@@ -775,16 +809,22 @@ def _torchao_embedding_quantizer(model):
     else:
         group_size = int(group_size)
     bitwidth = int(bitwidth)
+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
     return lambda model: EmbeddingQuantHandler(
         model,
         bitwidth=bitwidth,
         group_size=group_size,
         packed=(bitwidth in [2, 4]),
+        precision=torch_dtype,
     ).quantized_model()
 
 
-def get_quant_weight_transform(args, dtype_override, verbose):
-    # If these optional args are None, don't provide them to quantize()
+def get_quant_weight_transform(
+    args,
+    computation_dtype: Optional[DType] = None,
+    checkpoint_dtype: Optional[DType] = None,
+):
+    # If these optional args are None, don't provide them to quantize().
     quant_args_str = [
         "group_size",
         "calibration_tasks",
@@ -802,7 +842,8 @@ def get_quant_weight_transform(args, dtype_override, verbose):
         quantize,
         **quant_args,
         qmode=args.quantization_mode,
-        activation_dtype=dtype_override,
+        computation_dtype=computation_dtype,
+        checkpoint_dtype=checkpoint_dtype,
         checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
         tokenizer_path=(
             Path(path) if (path := args.tokenizer_path) is not None else None
@@ -829,4 +870,28 @@ def _load_torchao_aten_lib(libname):
     torch.ops.load_library(libs[0])
 
 
+# We want to do compute the actual ops in the computation dtype, since the precision of the
+# quantized linear will initially be the dtype of the checkpoint.
+def set_8da4w_computation_dtype(
+    module: nn.Module, computation_dtype: torch.dtype
+) -> nn.Module:
+
+    from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+
+    def _set_8da4w_computation_dtype(module: nn.Module, dtype: torch.dtype) -> None:
+        """
+        Recursively iterate through the module and set the precision attributes
+        of all Int8DynActInt4WeightLinears.
+        """
+        for _name, child in module.named_children():
+            if isinstance(child, Int8DynActInt4WeightLinear):
+                child.precision = dtype
+            else:
+                # Recursively apply to child modules
+                _set_8da4w_computation_dtype(child, dtype)
+
+    _set_8da4w_computation_dtype(module, computation_dtype)
+    return module
+
+
 ############################ Source Transform End #######################
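
Taken together, the 8da4w paths now follow a two-step recipe: quantize with the weights in the checkpoint dtype, then point the quantized linears' dequant/compute precision at the computation dtype. A standalone sketch of that flow on a toy module, using torchao's Int8DynActInt4WeightQuantizer as in the diff and the set_8da4w_computation_dtype helper added above (not an official recipe):

import torch
import torch.nn as nn
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
# set_8da4w_computation_dtype is the helper added in the diff above; module path inferred:
# from executorch.examples.models.llama.source_transformation.quantize import set_8da4w_computation_dtype

# Toy stand-in for a model loaded from a bf16 checkpoint.
model = nn.Sequential(nn.Linear(128, 128)).to(torch.bfloat16)

# 1. Quantize in the checkpoint dtype (quant params computed from the bf16 weights).
model = Int8DynActInt4WeightQuantizer(
    precision=torch.bfloat16, groupsize=32
).quantize(model)

# 2. Dequantize/compute in the export dtype by setting `precision` on each
#    Int8DynActInt4WeightLinear, which is what set_8da4w_computation_dtype does.
model = set_8da4w_computation_dtype(model, torch.float32)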

examples/models/llava/export_llava.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def forward(self, input_pos, embeddings):
     args = parser.parse_args(
         ["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
     )
-    quant_transform = get_quant_weight_transform(args, dtype_override, False)
+    quant_transform = get_quant_weight_transform(args, dtype_override)
     _, quantizers, _ = get_quantizer_and_quant_params(args)
     source_transforms = []
     if llava.use_sdpa_with_kv_cache_op:

exir/tests/test_memory_planning.py

Lines changed: 2 additions & 2 deletions
@@ -708,10 +708,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         et_program = et.executorch_program
         inputs = et_program.execution_plan[0].inputs
         self.assertNotEqual(
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[0]]
             .val.allocation_info.memory_offset_low,
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[1]]
             .val.allocation_info.memory_offset_low,
         )
