
Commit 68ead1b

Fix logic + separate weight and activation dtypes
1 parent 5e0f501 commit 68ead1b

File tree

examples/models/llama/export_llama_lib.py
examples/models/llama/model.py

2 files changed: +56 −13 lines
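
The commit title, "separate weight and activation dtypes", refers to keeping the dtype the weights are stored and quantized in (the checkpoint dtype) independent of the dtype the ops actually compute in (the dtype_override, fp32 here). Below is a minimal, illustrative sketch of that separation; `MixedDtypeLinear` and its names are hypothetical and not part of ExecuTorch or torchao.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedDtypeLinear(nn.Module):
    """Toy linear layer: weight stored in one dtype, matmul run in another."""

    def __init__(self, weight: torch.Tensor, compute_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=False)  # storage (checkpoint) dtype
        self.compute_dtype = compute_dtype  # activation/compute dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast the weight up at call time; activations stay in the compute dtype.
        return F.linear(x.to(self.compute_dtype), self.weight.to(self.compute_dtype))

layer = MixedDtypeLinear(torch.randn(8, 8, dtype=torch.bfloat16))
out = layer(torch.randn(2, 8))
print(layer.weight.dtype, out.dtype)  # torch.bfloat16 torch.float32
```

Storing weights in the checkpoint dtype avoids a lossy round trip during quantization, while computing in fp32 keeps the activation math in the precision requested by dtype_override.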

examples/models/llama/export_llama_lib.py

Lines changed: 40 additions & 9 deletions
@@ -46,6 +46,7 @@
     get_vulkan_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
+from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
 from ..model_factory import EagerModelFactory
 from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -57,6 +58,7 @@
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
+    QuantizedGroupEmbedding,
 )
 from .source_transformation.quantized_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
@@ -593,24 +595,53 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         args=args,
     )
+
+    # # Override dtype of the model as specified by the user args.
+    # if dtype_override:
+    #     assert isinstance(
+    #         dtype_override, DType
+    #     ), "Override dtype needs to be of type <DType>"
+    #     torch_dtype = dtype_override.to_torch_dtype()
+    #     logging.info(f"model.to {torch_dtype}")
+    #     edge_manager.model = edge_manager.model.to(dtype=torch_dtype)
+    #     metadata_str=args.metadata,
+    #     dtype_override=dtype_override,
+    #     args=args,
+    # )
+
     # Assumes the checkpoint has uniform dtype.
     checkpoint_dtype = next(edge_manager.model.parameters()).dtype
     print(f"checkpoint dtype: {checkpoint_dtype}")
-    # We want to quantize with the model in the checkpoint dtype before casting to dtype_override.
+    # We want to quantize the weights of the model in the checkpoint dtype.
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
             args.model, DType.from_torch_dtype(checkpoint_dtype), args
         )
     )
 
-    # Override dtype of the model as specified by the user args.
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        torch_dtype = dtype_override.to_torch_dtype()
-        logging.info(f"model.to {torch_dtype}")
-        edge_manager.model = edge_manager.model.to(dtype=torch_dtype)
+    quantized = torch.load("/home/jackzhxng/torchrepos/executorch/fake_quantized_weights.pt")
+    breakpoint()
+    # torch.testing.assert_close()
+
+    # We want to compute the actual ops in the precision of the dtype_override.
+    def _set_precision_to_fp32(module):
+        """
+        Recursively iterate through the module and set the precision attribute
+        of all Int8DynActInt4WeightLinear submodules to 'fp32'.
+        """
+        for name, child in module.named_children():
+            if isinstance(child, Int8DynActInt4WeightLinear):
+                # Change the precision attribute to 'fp32'
+                child.precision = torch.float32
+                print(f"Changed precision of {name} to torch.float32")
+            elif isinstance(child, QuantizedGroupEmbedding):
+                child.dtype = torch.float32
+                print(f"Changed precision of {name} to torch.float32")
+            else:
+                # Recursively apply to child modules
+                _set_precision_to_fp32(child)
+
+    _set_precision_to_fp32(edge_manager.model)
 
     return edge_manager
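
The `_set_precision_to_fp32` helper added above walks the module tree and flips the compute precision of quantized submodules after quantization has already happened in the checkpoint dtype. Here is a self-contained sketch of the same traversal pattern; `FakeQuantLinear` is a hypothetical stand-in for torchao's `Int8DynActInt4WeightLinear`, which the diff assumes exposes a `precision` attribute.

```python
import torch
import torch.nn as nn

class FakeQuantLinear(nn.Linear):
    """Stand-in for a quantized linear that carries a `precision` attribute."""

    precision: torch.dtype = torch.float16

def set_compute_precision(module: nn.Module, dtype: torch.dtype = torch.float32) -> None:
    """Recursively set the compute precision of all FakeQuantLinear submodules."""
    for name, child in module.named_children():
        if isinstance(child, FakeQuantLinear):
            child.precision = dtype
            print(f"Changed precision of {name} to {dtype}")
        else:
            # Recurse only into children that are not themselves quantized linears.
            set_compute_precision(child, dtype)

model = nn.Sequential(nn.Embedding(10, 8), FakeQuantLinear(8, 8))
set_compute_precision(model)
print(model[1].precision)  # torch.float32
```

As in the diff, matching submodules are not recursed into; only non-matching children are traversed further.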

examples/models/llama/model.py

Lines changed: 16 additions & 4 deletions
@@ -54,6 +54,7 @@ def __init__(self, **kwargs):
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.max_context_len = kwargs.get("max_context_len", 128)
+        self.dtype = kwargs.get("dtype_override", None)
         self.args = kwargs.get("args", None)
 
         assert (
@@ -123,7 +124,7 @@ def __init__(self, **kwargs):
         )
 
         # Get checkpoint dtype.
-        self.dtype = get_checkpoint_dtype(checkpoint)
+        self.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
 
         with open(params_path, "r") as f:
             params = json.loads(f.read())
@@ -171,7 +172,16 @@ def __init__(self, **kwargs):
         # Within the device="meta" context, tensors that are created do not carry data.
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
-            self.model_ = Transformer(model_args).to(dtype=self.dtype)
+            # Model itself is loaded in default dtype, fp32.
+            self.model_ = Transformer(model_args)
+            if self.dtype:
+                self.model_.to(dtype=self.dtype)
+
+            # Convert the model's weights only to the checkpoint's dtype, so that
+            # the checkpoint can be loaded into the model's state dict in its
+            # own dtype w/o potential precision loss.
+            for param in self.model_.parameters():
+                param.data = param.data.to(dtype=self.checkpoint_dtype)
 
         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
@@ -265,10 +275,12 @@ def __init__(self, **kwargs):
             self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
     def get_eager_model(self) -> torch.nn.Module:
-        if self.dtype:
+        return self.model_
+
+        if self.checkpoint_dtype:
             # convert to the type of the provided checkpoint
             # input and output are torch.long, so signature unchanged
-            return self.model_.to(self.dtype)
+            return self.model_.to(self.checkpoint_dtype)
         else:
             # int8 quantization code has some bf16,
             # switch all to FP32
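
The model.py change builds the Transformer on the meta device (default fp32, optionally dtype_override) and then casts only the parameters to the checkpoint dtype, so the checkpoint tensors can be attached without an implicit precision-changing cast. Below is a minimal sketch of that pattern under stated assumptions: `TinyModel`, the fabricated bf16 checkpoint, and the use of `load_state_dict(..., assign=True)` are illustrative, not the ExecuTorch loading code.

```python
import torch
import torch.nn as nn

# Pretend this came from disk; assume a uniform dtype, as the export code does.
checkpoint = {"linear.weight": torch.randn(4, 4, dtype=torch.bfloat16)}
checkpoint_dtype = next(iter(checkpoint.values())).dtype

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4, bias=False)

    def forward(self, x):
        return self.linear(x)

with torch.device("meta"):
    # Parameters carry shape/dtype metadata only, no storage.
    model = TinyModel()

# Mirror the diff: cast the parameters (weights only) to the checkpoint dtype.
for param in model.parameters():
    param.data = param.data.to(dtype=checkpoint_dtype)

# assign=True swaps the meta parameters for the real checkpoint tensors.
model.load_state_dict(checkpoint, strict=False, assign=True)
print(next(model.parameters()).dtype)  # torch.bfloat16
```

Keeping the module parameters in the checkpoint's dtype means any parameter that is not overwritten by the checkpoint still ends up in a dtype consistent with the loaded weights.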
