Commit 5e0f501

Fixes
1 parent cf0cb9d commit 5e0f501

File tree

examples/models/llama/export_llama_lib.py
extension/llm/export/builder.py

2 files changed: +37, -5 lines


examples/models/llama/export_llama_lib.py

Lines changed: 26 additions & 5 deletions
```diff
@@ -561,11 +561,13 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
-    # Conver dtype override string to actual type.
-    if args.quantization_mode in ["8da4w", "8da4w-gptq"]:
+    # Convert dtype override string to actual type.
+    if args.dtype_override is not None:
+        dtype_override = DType[args.dtype_override]
+    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
         dtype_override = DType["fp16"]
     else:
-        dtype_override = DType[args.dtype_override]
+        dtype_override = None
 
     edge_manager = _load_llama_model(
         args.model,
```
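The new branch order gives an explicit `--dtype_override` priority over the fp16 default that `8da4w`/`8da4w-gptq` quantization used to force, with `None` now meaning "keep the checkpoint's dtype". A minimal sketch of that precedence, where `resolve_dtype_override` is a hypothetical helper (the real code inlines this logic in `_prepare_for_llama_export`) and the trimmed `DType` stands in for the enum in `extension/llm/export/builder.py`:

```python
from enum import Enum

class DType(Enum):  # trimmed stand-in for the real enum
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"

def resolve_dtype_override(dtype_override_arg, quantization_mode):
    if dtype_override_arg is not None:  # explicit user choice wins
        return DType[dtype_override_arg]
    if quantization_mode in ["8da4w", "8da4w-gptq"]:
        return DType["fp16"]  # quantization default
    return None  # None means: keep the checkpoint dtype

assert resolve_dtype_override("bf16", "8da4w") == DType.bf16
assert resolve_dtype_override(None, "8da4w") == DType.fp16
assert resolve_dtype_override(None, None) is None
```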
```diff
@@ -590,7 +592,16 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         metadata_str=args.metadata,
         dtype_override=dtype_override,
         args=args,
-    ).set_output_dir(output_dir_path).source_transform(_get_source_transforms(args.model, dtype_override, args))
+    )
+    # Assumes the checkpoint has uniform dtype.
+    checkpoint_dtype = next(edge_manager.model.parameters()).dtype
+    print(f"checkpoint dtype: {checkpoint_dtype}")
+    # We want to quantize with the model in the checkpoint dtype before casting to dtype_override.
+    edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
+        _get_source_transforms(
+            args.model, DType.from_torch_dtype(checkpoint_dtype), args
+        )
+    )
 
     # Override dtype of the model as specified by the user args.
     if dtype_override:
```
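Splitting the method chain matters for ordering: source transforms (including quantization) now run while the model is still in its checkpoint dtype, and the cast to `dtype_override` happens afterwards. A toy illustration of that ordering, using `torch.nn.Linear` as a stand-in for the loaded Llama model:

```python
import torch

model = torch.nn.Linear(4, 4).to(torch.bfloat16)   # stand-in for the checkpoint

checkpoint_dtype = next(model.parameters()).dtype  # same probe the diff uses
assert checkpoint_dtype == torch.bfloat16

# ... quantization-style source transforms would run here, still in bf16 ...

model = model.to(torch.float16)                    # cast to dtype_override last
assert next(model.parameters()).dtype == torch.float16
```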
```diff
@@ -977,11 +988,21 @@ def _load_llama_model(
         )
     )
 
+    if dtype_override:
+        assert isinstance(
+            dtype_override, DType
+        ), "Override dtype needs to be of type <DType>"
+        dtype = dtype_override
+    else:
+        checkpoint_dtype = next(model.parameters()).dtype
+        dtype = DType.from_torch_dtype(checkpoint_dtype)
+    logging.info(f"Loaded model with dtype={dtype}")
+
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype_override,
+        dtype=dtype,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
```
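Note that `next(model.parameters()).dtype` inspects only the first parameter, which is why the earlier hunk comments that the checkpoint is assumed to have uniform dtype. A small demonstration of what the probe misses on a mixed-dtype module:

```python
import torch

mixed = torch.nn.Sequential(
    torch.nn.Linear(2, 2).to(torch.float16),
    torch.nn.Linear(2, 2).to(torch.float32),
)
# Only the first parameter is checked; the float32 layer goes unnoticed.
print(next(mixed.parameters()).dtype)  # torch.float16
```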

extension/llm/export/builder.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -61,6 +61,17 @@ def to_torch_dtype(self) -> torch.dtype:
             raise ValueError(f"Unsupported dtype {self}")
         return mapping[self]
 
+    @staticmethod
+    def from_torch_dtype(dtype: torch.dtype):
+        mapping = {
+            torch.float32: DType.fp32,
+            torch.float16: DType.fp16,
+            torch.bfloat16: DType.bf16,
+        }
+        if dtype not in mapping:
+            raise ValueError(f"Unsupported torch.dtype {dtype}")
+        return mapping[dtype]
+
 
 class LLMEdgeManager:
     """
```
