Skip to content

Commit 3c6c1ab

Browse files
author
Martin Yuan
committed
Refactor LLMEdgeManager's to_dtype
1 parent c890809 commit 3c6c1ab

File tree

2 files changed

+25
-19
lines changed

2 files changed

+25
-19
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -594,25 +594,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
594594
)
595595

596596
# At this point, the model is loaded in the default fp32.
597-
598-
# Checkpoint dtype should be lower or equal precision to the dtype override.
597+
# override dtype
599598
checkpoint_dtype = edge_manager.model.checkpoint_dtype
600-
if not (
601-
checkpoint_dtype == dtype_override.to_torch_dtype()
602-
or (
603-
checkpoint_dtype == torch.float16
604-
and dtype_override.to_torch_dtype() == torch.float32
605-
)
606-
or (
607-
checkpoint_dtype == torch.bfloat16
608-
and dtype_override.to_torch_dtype() == torch.float32
609-
)
610-
):
611-
logging.warning(
612-
f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
613-
)
614-
615-
edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
599+
edge_manager.to_dtype(dtype_override)
616600

617601
# We want to quantize (in the source transforms) the weights of the model
618602
# in the checkpoint dtype.

extension/llm/export/builder.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,29 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager":
147147
assert not dtype_override or isinstance(
148148
dtype_override, DType
149149
), "Override dtype needs to be of type <DType>"
150-
if dtype_override is not None and dtype_override != self.dtype:
150+
151+
# Checkpoint dtype should be lower or equal precision to the dtype override.
152+
checkpoint_dtype = (
153+
self.model.checkpoint_dtype
154+
if hasattr(self.model, "checkpoint_dtype")
155+
else None
156+
)
157+
if not (
158+
checkpoint_dtype == dtype_override.to_torch_dtype()
159+
or (
160+
checkpoint_dtype == torch.float16
161+
and dtype_override.to_torch_dtype() == torch.float32
162+
)
163+
or (
164+
checkpoint_dtype == torch.bfloat16
165+
and dtype_override.to_torch_dtype() == torch.float32
166+
)
167+
):
168+
logging.warning(
169+
f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
170+
)
171+
172+
if dtype_override != self.dtype:
151173
torch_dtype = dtype_override.to_torch_dtype()
152174
logging.info(f"model.to {torch_dtype}")
153175
self.model = self.model.to(dtype=torch_dtype)

0 commit comments

Comments (0)