Commit e0b9234

Use model.to approach, forget edge case for checkpoint dtype > dtype_override precision
1 parent 5daaf19 commit e0b9234

File tree

2 files changed: +11 −20 lines changed


examples/models/llama/export_llama_lib.py

Lines changed: 10 additions & 17 deletions
@@ -56,9 +56,9 @@
 from .source_transformation.attention import replace_attention_to_attention_sha
 from .source_transformation.quantize import (
-    set_quantized_computation_dtype,
     get_quant_embedding_transform,
     get_quant_weight_transform,
+    set_quantized_computation_dtype,
 )
 from .source_transformation.quantized_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
@@ -596,31 +596,24 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
 
     # At this point, the model is loaded in the default fp32.
 
-    # Convert the non-weights of the model (the buffers) to the dtype_override.
-    # Need to do this before source transform quantization since the quantized
-    # parameters become buffers.
-    for buf in edge_manager.model.buffers():
-        buf.data = buf.data.to(dtype=dtype_override.to_torch_dtype())
+    # TODO: some validation for the combination of checkpoint dtype and dtype_override.
+
+    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
 
     # We want to quantize (in the source transforms) the weights of the model
     # in the checkpoint dtype.
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            args.model,
-            dtype_override,
-            DType.from_torch_dtype(edge_manager.model.checkpoint_dtype),
-            args,
+            modelname=args.model,
+            dtype_override=dtype_override,
+            checkpoint_dtype=DType.from_torch_dtype(
+                edge_manager.model.checkpoint_dtype
+            ),
+            args=args,
         )
     )
 
-    # Convert the parameters to the dtype_override.
-    # If source transform quantization has already happened at this point (-qmode),
-    # the quantized weights will become buffers and not be returned by .parameters(),
-    # so we don't convert them to the dtype_override.
-    for param in edge_manager.model.parameters():
-        param.data = param.data.to(dtype=dtype_override.to_torch_dtype())
-
     return edge_manager
 
 

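Note on the change above: casting the whole model once, before the source-transform quantization runs, is what lets the two separate loops over buffers and parameters (and the ordering constraint they encoded) be dropped. As a minimal standalone sketch, not part of this commit and using a hypothetical TinyModel, this is the PyTorch behavior the new single call relies on: nn.Module.to(dtype=...) casts every floating-point parameter and buffer of the module in one pass.

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)                   # weight/bias are parameters
        self.register_buffer("scale", torch.ones(4))    # non-parameter state lives in a buffer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) * self.scale


model = TinyModel()                    # created in the default fp32
model = model.to(dtype=torch.float16)  # one call casts parameters *and* buffers

assert all(p.dtype == torch.float16 for p in model.parameters())
assert all(b.dtype == torch.float16 for b in model.buffers())
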
examples/models/llama/source_transformation/quantize.py

Lines changed: 1 addition & 3 deletions
@@ -845,9 +845,7 @@ def _load_torchao_aten_lib(libname):
 # We want to do compute the actual ops in the dtype of the dtype_override,
 # since the precision of the quantized linear will initially be the dtype of the
 # checkpoint, not the dtype_override.
-def set_quantized_computation_dtype(
-    module: nn.Module, dtype: torch.dtype
-) -> nn.Module:
+def set_quantized_computation_dtype(module: nn.Module, dtype: torch.dtype) -> nn.Module:
     def _set_quantized_computation_dtype_rec(
         module: nn.Module, dtype: torch.dtype
     ) -> None:
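
The body of set_quantized_computation_dtype is not shown in this diff; only the signature and the inner recursive helper's header are visible. Purely as an illustrative sketch of that shape (the QuantizedLinear stand-in and its precision attribute are assumptions, not the real quantized op), the recursive helper could walk the module tree like this:

import torch
import torch.nn as nn


class QuantizedLinear(nn.Module):
    """Hypothetical stand-in for a source-transformed quantized linear op."""

    def __init__(self) -> None:
        super().__init__()
        self.precision = torch.float32  # dtype the op computes in; assumed attribute


def set_quantized_computation_dtype(module: nn.Module, dtype: torch.dtype) -> nn.Module:
    def _set_quantized_computation_dtype_rec(module: nn.Module, dtype: torch.dtype) -> None:
        for child in module.children():
            if isinstance(child, QuantizedLinear):
                # Compute in the dtype_override rather than the checkpoint dtype.
                child.precision = dtype
            _set_quantized_computation_dtype_rec(child, dtype)

    _set_quantized_computation_dtype_rec(module, dtype)
    return module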
