
Commit 91c0d0c

jackzhxng authored and facebook-github-bot committed
Fix xnnpack quantization discrepancy for non-fp32 (#8488)
Summary:

Perform quantization on the weights in their original dtype (from the checkpoint) by running the source transformations before the dtype cast. Previously the model was converted to the `dtype_override` arg's dtype and then quantized, which introduced quantization noise; this change eliminates it.

Note: no need to worry about https://github.com/pytorch/ao/blob/main/torchao/quantization/GPTQ.py#L1168, since precision is passed in with the checkpoint dtype.

### Comparison of an arbitrary q_proj tensor from a sample Llama checkpoint

Before:

```
Mismatched elements: 3260378 / 4194304 (77.7%)
Greatest absolute difference: 0.08802086114883423 at index (1129, 604) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (0, 1350) (up to 1.3e-06 allowed)
Signal-to-noise: 32.8974 dB
```

After: no difference.

Test Plan:

### Manual testing

```
python -m examples.models.llama.export_llama \
  -v -c xl_consolidated/consolidated_renamed.pth \
  -p xl_consolidated/et_params.json -kv -d fp32 \
  -qmode 8da4w --group_size 32 -X \
  --use_sdpa_with_kv_cache \
  --output_name quantized_baseline.pte \
  --max_context_length 4096 -E 4,32
```

With the following inserted after the quantization:

```
edge_manager.model(
    torch.tensor([[2, 3, 4]], dtype=torch.long),
    {"input_pos": torch.tensor([0], dtype=torch.long)},
)
```

And the modifications to GPTQ.py in torchao from pytorch/ao#1756 for testing.

### Automated testing

Existing CI tests.

### Regression testing

TBD

Differential Revision: D70184325

Pulled By: jackzhxng
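The discrepancy can be reproduced outside the exporter with a toy group-wise quantizer. The sketch below is hand-rolled for illustration only (it is not the torchao 8da4w implementation, and the tensor shape, group size, and names are arbitrary): quantizing weights that were first cast to a lower-precision override dtype yields different scales and codes than quantizing the checkpoint-dtype weights directly and casting afterwards.

```python
import torch


def int4_groupwise_dequant(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Toy symmetric group-wise int4 quantize + dequantize (illustration only)."""
    groups = w.to(torch.float32).reshape(-1, group_size)
    scales = groups.abs().amax(dim=1, keepdim=True) / 7.0  # map the per-group max to int4 code 7
    scales = torch.where(scales == 0, torch.ones_like(scales), scales)  # avoid divide-by-zero
    q = torch.clamp(torch.round(groups / scales), -8, 7)
    return (q * scales).reshape(w.shape)


# Pretend these are checkpoint weights (e.g. a q_proj slice) stored in fp32.
w_checkpoint = torch.randn(4096, 64)

# Old flow: cast to the dtype_override (fp16 here) first, then quantize.
old = int4_groupwise_dequant(w_checkpoint.to(torch.float16))

# New flow: quantize the checkpoint-dtype weights directly; the cast happens afterwards.
new = int4_groupwise_dequant(w_checkpoint)

print((old - new).abs().max())  # nonzero: the fp16 round trip shifts scales and rounding
```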
1 parent dedfdaf commit 91c0d0c

File tree

6 files changed: +147 -84 lines


examples/models/checkpoint.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -9,6 +9,8 @@
 from pathlib import Path
 from typing import Any, Dict, Optional
 
+import torch
+
 
 def get_default_model_resource_dir(model_file_path: str) -> Path:
     """
@@ -52,7 +54,7 @@ def get_default_model_resource_dir(model_file_path: str) -> Path:
     return resource_dir
 
 
-def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
+def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[torch.dtype]:
    """
    Get the dtype of the checkpoint, returning "None" if the checkpoint is empty.
    """
```

examples/models/llama/export_llama_lib.py

Lines changed: 71 additions & 64 deletions
```diff
@@ -15,6 +15,7 @@
 import re
 import shlex
 from enum import Enum
+from functools import partial
 from json import JSONDecodeError
 from pathlib import Path
 from typing import Callable, List, Optional, Union
@@ -55,6 +56,7 @@
 
 from .source_transformation.attention import replace_attention_to_attention_sha
 from .source_transformation.quantize import (
+    set_quantized_computation_dtype,
     get_quant_embedding_transform,
     get_quant_weight_transform,
 )
@@ -563,43 +565,63 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
-    # dtype override
-    if args.dtype_override is not None:
-        dtype_override = DType[args.dtype_override]
-    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
-        dtype_override = DType["fp16"]
-    else:
-        dtype_override = None
+    # Convert dtype override string arg to actual type.
+    dtype_override = DType[args.dtype_override]
+
+    edge_manager = _load_llama_model(
+        args.model,
+        checkpoint=checkpoint_path,
+        checkpoint_dir=checkpoint_dir,
+        params_path=params_path,
+        use_kv_cache=args.use_kv_cache,
+        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+        generate_full_logits=args.generate_full_logits,
+        weight_type=weight_type,
+        enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
+        verbose=args.verbose,
+        max_seq_len=args.max_seq_length,
+        max_context_len=args.max_context_length,
+        input_prune_map_path=args.input_prune_map,
+        output_prune_map_path=args.output_prune_map,
+        metadata_str=args.metadata,
+        dtype_override=dtype_override,
+        args=args,
+    )
+
+    # At this point, the model is loaded in the default fp32.
+
+    # Convert the non-weights of the model (the buffers) to the dtype_override.
+    # Need to do this before source transform quantization since the quantized
+    # parameters become buffers.
+    for buf in edge_manager.model.buffers():
+        buf.data = buf.data.to(dtype=dtype_override.to_torch_dtype())
 
-    return (
-        _load_llama_model(
+    # We want to quantize (in the source transforms) the weights of the model
+    # in the checkpoint dtype.
+    logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
+    edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
+        _get_source_transforms(
             args.model,
-            checkpoint=checkpoint_path,
-            checkpoint_dir=checkpoint_dir,
-            params_path=params_path,
-            use_kv_cache=args.use_kv_cache,
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            generate_full_logits=args.generate_full_logits,
-            weight_type=weight_type,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            calibration_data=args.calibration_data,
-            tokenizer_path=args.tokenizer_path,
-            verbose=args.verbose,
-            max_seq_len=args.max_seq_length,
-            max_context_len=args.max_context_length,
-            input_prune_map_path=args.input_prune_map,
-            output_prune_map_path=args.output_prune_map,
-            metadata_str=args.metadata,
-            dtype_override=dtype_override,
-            args=args,
+            dtype_override,
+            DType.from_torch_dtype(edge_manager.model.checkpoint_dtype),
+            args,
         )
-        .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(args.model, dtype_override, args))
     )
 
+    # Convert the parameters to the dtype_override.
+    # If source transform quantization has already happened at this point (-qmode),
+    # the quantized weights will become buffers and not be returned by .parameters(),
+    # so we don't convert them to the dtype_override.
+    for param in edge_manager.model.parameters():
+        param.data = param.data.to(dtype=dtype_override.to_torch_dtype())
+
+    return edge_manager
+
 
 def get_quantizer_and_quant_params(args):
     pt2e_quant_params = get_pt2e_quantization_params(
@@ -783,8 +805,6 @@ def _to_edge_and_lower_llama( # noqa: C901
             shares=args.num_sharding,
         )
 
-        from functools import partial
-
         # pyre-ignore
         from executorch.backends.qualcomm.quantizer.custom_annotation import (
            get_custom_quant_ios_dtype,
@@ -1004,6 +1024,8 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
+
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
@@ -1020,41 +1042,16 @@ def _load_llama_model(
             enable_dynamic_shape=enable_dynamic_shape,
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
+            dtype=torch_dtype,
             args=args,
         )
     )
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        torch_dtype = dtype_override.to_torch_dtype()
-        logging.info(f"model.to {torch_dtype}")
-        model = model.to(dtype=torch_dtype)
-        dtype = dtype_override
-    else:
-        state_dict = model.state_dict()
-        dtype = state_dict[next(iter(state_dict))].dtype
-        assert dtype in [
-            torch.bfloat16,
-            torch.float16,
-            torch.float32,
-        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-        logging.info(f"Loaded model with dtype={dtype}")
-
-        if dtype == torch.bfloat16:
-            dtype = DType.bf16
-        elif dtype == torch.float16:
-            dtype = DType.fp16
-        elif dtype == torch.float32:
-            dtype = DType.fp32
-        else:
-            raise ValueError(f"Unsupported dtype {dtype}")
 
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype,
+        dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
@@ -1091,7 +1088,10 @@ def _load_llama_model(
 
 
 def _get_source_transforms( # noqa
-    modelname: str, dtype_override: Optional[DType], args
+    modelname: str,
+    dtype_override: DType,
+    checkpoint_dtype: Optional[DType],
+    args,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     transforms = []
 
@@ -1125,7 +1125,7 @@ def _get_source_transforms( # noqa
        """
        modelname = f"{modelname}_q"
        transforms.append(
-           get_quant_weight_transform(args, dtype_override, verbose_export())
+           get_quant_weight_transform(args, checkpoint_dtype, verbose_export())
        )
 
    if args.embedding_quantize:
@@ -1139,7 +1139,14 @@ def _get_source_transforms( # noqa
        this wil be a no-op.
        """
        modelname = f"{modelname}_e"
-       transforms.append(get_quant_embedding_transform(args))
+       transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
+
+   if args.quantization_mode or args.embedding_quantize:
+       transforms.append(
+           partial(
+               set_quantized_computation_dtype, dtype=dtype_override.to_torch_dtype()
+           )
+       )
 
    if args.expand_rope_table:
        transforms.append(materialze_broadcast_of_rope_freq_cis)
```
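The comments in this diff rely on a PyTorch detail worth spelling out: tensors registered as buffers are returned by `Module.buffers()` but not by `Module.parameters()`, so once source-transform quantization replaces float weights with int buffers, the later parameter-only cast leaves them untouched. A minimal stand-in module (a toy with a similar buffer layout, not the torchao 8da4w linear) shows the effect:

```python
import torch
from torch import nn


class ToyQuantizedLinear(nn.Module):
    """Stand-in for a quantized linear: the int weight and scales are buffers, not parameters."""

    def __init__(self, out_features: int, in_features: int):
        super().__init__()
        self.register_buffer("weight_int8", torch.zeros(out_features, in_features, dtype=torch.int8))
        self.register_buffer("scales", torch.ones(out_features, 1, dtype=torch.float32))


model = nn.Sequential(ToyQuantizedLinear(8, 8), nn.Linear(8, 8))

# Casting parameters only touches the remaining float weights (here the plain nn.Linear);
# the int8 buffer is skipped, which is why the dtype cast can safely run after quantization.
for param in model.parameters():
    param.data = param.data.to(dtype=torch.float16)

print({name: b.dtype for name, b in model.named_buffers()})     # buffers keep int8 / fp32
print({name: p.dtype for name, p in model.named_parameters()})  # parameters are now fp16
```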

examples/models/llama/model.py

Lines changed: 7 additions & 11 deletions
```diff
@@ -122,9 +122,6 @@ def __init__(self, **kwargs):
                """
            )
 
-        # Get checkpoint dtype.
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
@@ -171,7 +168,9 @@ def __init__(self, **kwargs):
         # Within the device="meta" context, tensors that are created do not carry data.
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
+            # Model itself is loaded in default dtype, fp32.
             self.model_ = Transformer(model_args)
+            self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
 
         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
@@ -241,6 +240,10 @@ def __init__(self, **kwargs):
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+
+        # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
+        # by default initialized to fp32. This is fine because every other supported type
+        # losslessly converts to fp32, so we don't lose precision here.
         missing, unexpected = self.model_.load_state_dict(
             checkpoint,
             strict=False,
@@ -277,14 +280,7 @@ def __init__(self, **kwargs):
             self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
     def get_eager_model(self) -> torch.nn.Module:
-        if self.dtype:
-            # convert to the type of the provided checkpoint
-            # input and output are torch.long, so signature unchanged
-            return self.model_.to(self.dtype)
-        else:
-            # int8 quantization code has some bf16,
-            # switch all to FP32
-            return self.model_.to(torch.float32)
+        return self.model_
 
     def get_example_inputs(self):
         if self.use_kv_cache:
```
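The new comment about loading the checkpoint into the default-fp32 `Transformer` being lossless follows from fp32 being a superset of the other supported dtypes: every bf16 and fp16 value is exactly representable in fp32. A quick check (assuming CPU tensors with values from `randn`, purely for illustration):

```python
import torch

w_bf16 = torch.randn(1024).to(torch.bfloat16)
w_fp16 = torch.randn(1024).to(torch.float16)

# Upcasting to fp32 and back reproduces the original values exactly, so no precision is lost.
assert torch.equal(w_bf16.to(torch.float32).to(torch.bfloat16), w_bf16)
assert torch.equal(w_fp16.to(torch.float32).to(torch.float16), w_fp16)
```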
