
Commit 82d748d

Cover case where some weights are unquantized
1 parent a73f107 commit 82d748d

4 files changed: +83 -61 lines changed


examples/models/checkpoint.py

Lines changed: 3 additions & 1 deletion
@@ -9,6 +9,8 @@
 from pathlib import Path
 from typing import Any, Dict, Optional
 
+import torch
+
 
 def get_default_model_resource_dir(model_file_path: str) -> Path:
     """
@@ -52,7 +54,7 @@ def get_default_model_resource_dir(model_file_path: str) -> Path:
     return resource_dir
 
 
-def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
+def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[torch.dtype]:
     """
     Get the dtype of the checkpoint, returning "None" if the checkpoint is empty.
     """

examples/models/llama/export_llama_lib.py

Lines changed: 30 additions & 23 deletions
@@ -15,6 +15,7 @@
 import re
 import shlex
 from enum import Enum
+from functools import partial
 from json import JSONDecodeError
 from pathlib import Path
 from typing import Callable, List, Optional, Union
@@ -597,20 +598,31 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         args=args,
     )
 
-    # Assumes the checkpoint has uniform dtype.
-    checkpoint_dtype = next(edge_manager.model.parameters()).dtype
-    logging.info(f"checkpoint dtype: {checkpoint_dtype}")
-    # We want to quantize the weights of the model in the checkpoint dtype.
+    # At this point, the model is loaded in the default fp32.
+
+    # Convert the non-weights of the model (the buffers) to the dtype_override.
+    # Need to do this before source transform quantization since the quantized
+    # parameters become buffers.
+    for buf in edge_manager.model.buffers():
+        buf.data = buf.data.to(dtype=dtype_override.to_torch_dtype())
+
+    # We want to quantize (in the source transforms) the weights of the model
+    # in the checkpoint dtype.
+    logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            args.model, DType.from_torch_dtype(checkpoint_dtype), args
+            args.model,
+            DType.from_torch_dtype(edge_manager.model.checkpoint_dtype),
+            args,
         )
     )
 
-    _set_quantized_computation_dtype(
-        edge_manager.model,
-        dtype_override.to_torch_dtype(),  # pyre-ignore[16]
-    )
+    # Convert the parameters to the dtype_override.
+    # If source transform quantization has already happened at this point (-qmode),
+    # the quantized weights will become buffers and not be returned by .parameters(),
+    # so we don't convert them to the dtype_override.
+    for param in edge_manager.model.parameters():
+        param.data = param.data.to(dtype=dtype_override.to_torch_dtype())
 
     return edge_manager
 
@@ -785,8 +797,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             shares=args.num_sharding,
         )
 
-        from functools import partial
-
         # pyre-ignore
         from executorch.backends.qualcomm.quantizer.custom_annotation import (
             get_custom_quant_ios_dtype,
@@ -989,21 +999,11 @@ def _load_llama_model(
         )
     )
 
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        dtype = dtype_override
-    else:
-        checkpoint_dtype = next(model.parameters()).dtype
-        dtype = DType.from_torch_dtype(checkpoint_dtype)
-    logging.info(f"Loaded model with dtype={dtype}")
-
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype,
+        dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
@@ -1088,7 +1088,14 @@ def _get_source_transforms(  # noqa
        this wil be a no-op.
        """
        modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args))
+        transforms.append(get_quant_embedding_transform(args, dtype_override))
+
+    if args.quantization_mode or args.embedding_quantize:
+        transforms.append(
+            partial(
+                _set_quantized_computation_dtype, dtype=dtype_override.to_torch_dtype()
+            )
+        )
 
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
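The ordering above matters because source-transform quantization re-registers quantized weights as buffers, which .parameters() no longer returns. The toy script below mirrors the three steps with invented names (this is not ExecuTorch code, just a sketch of the mechanism): cast buffers before quantization, quantize, then cast only the remaining float parameters, leaving the int8 data untouched.

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)                      # will be "quantized"
        self.head = nn.Linear(8, 8)                        # stays unquantized
        self.register_buffer("rope_freqs", torch.ones(8))  # a non-weight buffer


model = ToyModel()
override = torch.bfloat16

# 1) Convert the buffers first, while the only buffers are the fp "non-weights".
for buf in model.buffers():
    buf.data = buf.data.to(dtype=override)

# 2) Stand-in for source-transform quantization: the float parameter is removed
#    and the quantized weight comes back as an int8 buffer.
int8_weight = torch.randint(-128, 127, model.linear.weight.shape, dtype=torch.int8)
del model.linear.weight
model.linear.register_buffer("weight", int8_weight)

# 3) Convert whatever .parameters() still returns; the int8 buffer is skipped.
for param in model.parameters():
    param.data = param.data.to(dtype=override)

print(model.linear.weight.dtype)  # torch.int8 (untouched)
print(model.head.weight.dtype)    # torch.bfloat16
print(model.rope_freqs.dtype)     # torch.bfloat16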

examples/models/llama/model.py

Lines changed: 5 additions & 12 deletions
@@ -54,7 +54,6 @@ def __init__(self, **kwargs):
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.max_context_len = kwargs.get("max_context_len", 128)
-        self.dtype = kwargs.get("dtype", None)
         self.args = kwargs.get("args", None)
 
         assert (
@@ -123,9 +122,6 @@ def __init__(self, **kwargs):
                """
            )
 
-        # Get checkpoint dtype.
-        self.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
-
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
@@ -174,14 +170,7 @@ def __init__(self, **kwargs):
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
             self.model_ = Transformer(model_args)
-            if self.dtype:
-                self.model_.to(dtype=self.dtype)
-
-        # Convert the model's weights only to the checkpoint's dtype, so that
-        # the checkpoint can be loaded into the model's state dict in its
-        # own dtype w/o potential precision loss.
-        for param in self.model_.parameters():
-            param.data = param.data.to(dtype=self.checkpoint_dtype)
+            self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
 
         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
@@ -251,6 +240,10 @@ def __init__(self, **kwargs):
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+
+        # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
+        # by default initialized to fp32. This is fine because every other supported type
+        # losslessly converts to fp32, so we don't lose precision here.
         missing, unexpected = self.model_.load_state_dict(
             checkpoint,
             strict=False,
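The new comment relies on fp32 being a strict superset of the other supported floating-point checkpoint dtypes. A quick standalone check of that claim (an illustration, not part of the commit):

import torch

# Widening fp16/bf16 values to fp32 and narrowing back reproduces the original
# values exactly, so loading a low-precision checkpoint into an fp32 module
# does not lose information.
for dtype in (torch.float16, torch.bfloat16):
    original = torch.randn(1024).to(dtype)
    widened = original.to(torch.float32)
    assert torch.equal(widened.to(dtype), original), f"lossy round-trip for {dtype}"
print("fp16/bf16 -> fp32 promotion is lossless")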

examples/models/llama/source_transformation/quantize.py

Lines changed: 45 additions & 25 deletions
@@ -72,7 +72,9 @@ def quantize(  # noqa C901
 
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
-        return WeightOnlyInt8QuantHandler(model).quantized_model()
+        return WeightOnlyInt8QuantHandler(
+            model, precision=torch_dtype
+        ).quantized_model()
     elif qmode.startswith("torchao:fpa"):
         pattern = r"torchao:fpa(\d+)w"
         matches = re.findall(pattern, qmode)
@@ -85,7 +87,7 @@ def quantize(  # noqa C901
         model = (
             UIntxWeightOnlyLinearQuantizer(
                 device="mps",
-                precision=torch.float32,
+                precision=torch_dtype,
                 groupsize=group_size,
                 bitwidth=bitwidth,
             )
@@ -107,7 +109,7 @@ def quantize(  # noqa C901
         with torch.no_grad():
             model = Int8DynActIntxWeightLinearQuantizer(
                 device="cpu",
-                precision=torch.float32,
+                precision=torch_dtype,
                 groupsize=group_size,
                 bitwidth=bitwidth,
                 has_weight_zeros=False,
@@ -346,6 +348,7 @@ def __init__(
         node_type: str = "*",
         bitwidth: Optional[int] = None,
         group_size: Optional[int] = None,
+        precision: torch.dtype = torch.float32,
     ):
         self.mod = mod
         self.group_size = group_size
@@ -354,6 +357,7 @@ def __init__(
             self.bitwidth = 8
         else:
             self.bitwidth = bitwidth
+        self.precision = precision
 
     @torch.no_grad()
     def create_quantized_state_dict(self) -> Dict:
@@ -389,7 +393,7 @@ def create_quantized_state_dict(self) -> Dict:
 
                 # print(f"expanded weight shape {input_weight.shape}")
                 weight, scales, _ = dynamically_quantize_per_channel(
-                    input_weight,
+                    input_weight.to(dtype=self.precision),
                     range_min,
                     range_max,
                     torch.int8,
@@ -574,6 +578,7 @@ def __init__(
         bitwidth: int = 8,
         group_size: Optional[int] = None,
         packed=False,
+        precision: Optional[torch.dtype] = None,
     ):
         if isinstance(packed, str):
             packed = packed == "True"
@@ -582,6 +587,8 @@ def __init__(
         self.group_size = group_size
         self.bitwidth = bitwidth
         self.packed = packed
+        # Dtype of the weights right before quantization.
+        self.precision = precision
         if (bitwidth not in [2, 4]) and packed:
             raise RuntimeError("pack only works with bitsize 2, 4")
 
@@ -612,7 +619,11 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                     f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
                 weight, scales, _ = dynamically_quantize_per_channel(
-                    mod.weight,
+                    (
+                        mod.weight.to(dtype=self.precision)
+                        if self.precision
+                        else mod.weight
+                    ),
                     range_min,
                     range_max,
                     torch.int8,
@@ -748,7 +759,7 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 ############################ Source Transform Start #######################
 
 
-def get_quant_embedding_transform(args):
+def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
     if args.embedding_quantize.startswith("torchao:"):
         bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
         group_size = int(group_size)
@@ -774,11 +785,13 @@ def _torchao_embedding_quantizer(model):
     else:
         group_size = int(group_size)
         bitwidth = int(bitwidth)
+        torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
         return lambda model: EmbeddingQuantHandler(
             model,
             bitwidth=bitwidth,
             group_size=group_size,
             packed=(bitwidth in [2, 4]),
+            precision=torch_dtype,
         ).quantized_model()
 
 
@@ -831,25 +844,32 @@ def _load_torchao_aten_lib(libname):
 # We want to do compute the actual ops in the dtype of the dtype_override,
 # since the precision of the quantized linear will initially be the dtype of the
 # checkpoint, not the dtype_override.
-# TODO(#8652): this is a temporary solution for until we can support the new ao,
-# quantize_ api, which apparently can support different dtypes at quantization and
-# computation.
-def _set_quantized_computation_dtype(module: nn.Module, dtype: torch.dtype):
-    """
-    Recursively iterate through the module and set the dtype/precision attributes
-    of all Int8DynActInt4WeightLinear and QuantizedGroupEmbedding submodules to 'fp32'.
-    """
-    for name, child in module.named_children():
-        if isinstance(child, Int8DynActInt4WeightLinear):
-            # Change the precision attribute to 'fp32'
-            child.precision = dtype
-            print(f"Changed precision of {name} to {dtype}")
-        elif isinstance(child, QuantizedGroupEmbedding):
-            child.dtype = dtype
-            print(f"Changed precision of {name} to {dtype}")
-        else:
-            # Recursively apply to child modules
-            _set_quantized_computation_dtype(child, dtype)
+def _set_quantized_computation_dtype(
+    module: nn.Module, dtype: torch.dtype
+) -> nn.Module:
+    def _set_quantized_computation_dtype_rec(
+        module: nn.Module, dtype: torch.dtype
+    ) -> None:
+        """
+        Recursively iterate through the module and set the dtype/precision attributes
+        of all Int8DynActInt4WeightLinear and QuantizedGroupEmbedding submodules to 'fp32'.
+        """
+        for name, child in module.named_children():
+            if isinstance(child, Int8DynActInt4WeightLinear):
+                # Change the precision attribute to 'fp32'
+                child.precision = dtype
+                print(f"Changed precision of {name} to {dtype}")
+            elif isinstance(child, QuantizedGroupEmbedding):
+                child.dtype = dtype
+                print(f"Changed precision of {name} to {dtype}")
+            elif isinstance(child, WeightOnlyInt8Linear):
+                child.dtype = dtype
+            else:
+                # Recursively apply to child modules
+                _set_quantized_computation_dtype_rec(child, dtype)
+
+    _set_quantized_computation_dtype_rec(module, dtype)
+    return module
 
 
 ############################ Source Transform End #######################