Commit 9192d58

Martin Yuan authored; facebook-github-bot committed

Refactor LLMEdgeManager's to_dtype (#9542)

Summary: Small refactor to aggregate the dtype-override handling into the to_dtype method.

Test Plan: CI

Reviewed By: billmguo

Differential Revision: D71736608

Pulled By: iseeyuan

1 parent 012f120 · commit 9192d58
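
The rule being consolidated is visible in both diffs below: the checkpoint dtype must match the override, or be a fp16/bf16 checkpoint widened to fp32; any other combination only logs a warning. A minimal runnable sketch of that predicate, with an illustrative helper name that is not taken from the commit:

```python
import torch

def checkpoint_dtype_ok(checkpoint: torch.dtype, override: torch.dtype) -> bool:
    """Illustrative helper mirroring the check that moves into to_dtype():
    the checkpoint must be at lower-or-equal precision than the override."""
    return (
        checkpoint == override
        or (checkpoint == torch.float16 and override == torch.float32)
        or (checkpoint == torch.bfloat16 and override == torch.float32)
    )

assert checkpoint_dtype_ok(torch.bfloat16, torch.float32)     # widening: fine
assert not checkpoint_dtype_ok(torch.float32, torch.float16)  # narrowing: warns
```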

File tree

2 files changed: +22 −19 lines

examples/models/llama/export_llama_lib.py
extension/llm/export/builder.py

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 18 deletions

```diff
@@ -618,25 +618,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     )
 
     # At this point, the model is loaded in the default fp32.
-
-    # Checkpoint dtype should be lower or equal precision to the dtype override.
+    # override dtype
     checkpoint_dtype = edge_manager.model.checkpoint_dtype
-    if not (
-        checkpoint_dtype == dtype_override.to_torch_dtype()
-        or (
-            checkpoint_dtype == torch.float16
-            and dtype_override.to_torch_dtype() == torch.float32
-        )
-        or (
-            checkpoint_dtype == torch.bfloat16
-            and dtype_override.to_torch_dtype() == torch.float32
-        )
-    ):
-        logging.warning(
-            f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
-        )
-
-    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    edge_manager.to_dtype(dtype_override)
 
     # We want to quantize (in the source transforms) the weights of the model
     # in the checkpoint dtype.
```
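
After the change, the call site only touches the manager. A self-contained sketch of that shape, using toy stand-ins (the real LLMEdgeManager and DType live in extension/llm/export/builder.py; these classes are illustrative only):

```python
import torch
import torch.nn as nn

class ToyDType:
    """Stand-in for the exporter's DType wrapper."""
    def to_torch_dtype(self) -> torch.dtype:
        return torch.float16

class ToyEdgeManager:
    """Stand-in for LLMEdgeManager, reduced to the refactored method."""
    def __init__(self, model: nn.Module):
        self.model = model

    def to_dtype(self, dtype_override: ToyDType) -> "ToyEdgeManager":
        # The cast (and, in the real class, the checkpoint-dtype check)
        # now lives here instead of at the call site.
        self.model = self.model.to(dtype=dtype_override.to_torch_dtype())
        return self

edge_manager = ToyEdgeManager(nn.Linear(4, 4))
edge_manager.to_dtype(ToyDType())  # replaces the manual model.to(...) call
print(next(edge_manager.model.parameters()).dtype)  # torch.float16
```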

extension/llm/export/builder.py

Lines changed: 20 additions & 1 deletion

```diff
@@ -147,7 +147,26 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager":
         assert not dtype_override or isinstance(
             dtype_override, DType
         ), "Override dtype needs to be of type <DType>"
-        if dtype_override is not None and dtype_override != self.dtype:
+
+        # Checkpoint dtype should be lower or equal precision to the dtype override.
+        if hasattr(self.model, "checkpoint_dtype"):
+            checkpoint_dtype = self.model.checkpoint_dtype
+            if not (
+                checkpoint_dtype == dtype_override.to_torch_dtype()
+                or (
+                    checkpoint_dtype == torch.float16
+                    and dtype_override.to_torch_dtype() == torch.float32
+                )
+                or (
+                    checkpoint_dtype == torch.bfloat16
+                    and dtype_override.to_torch_dtype() == torch.float32
+                )
+            ):
+                logging.warning(
+                    f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
+                )
+
+        if dtype_override != self.dtype:
             torch_dtype = dtype_override.to_torch_dtype()
             logging.info(f"model.to {torch_dtype}")
             self.model = self.model.to(dtype=torch_dtype)
```
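
Two behavioral details worth noting in the hunk above: the precision check only runs when the model exposes a checkpoint_dtype attribute, and the warning is advisory, so the cast still proceeds. A small runnable sketch of that flow, using a plain nn.Module as a stand-in for the loaded model:

```python
import logging
import torch
import torch.nn as nn

logging.basicConfig(level=logging.WARNING)

model = nn.Linear(4, 4)                 # stand-in model; loaded in default fp32
model.checkpoint_dtype = torch.float32  # attribute the Llama loader sets
override = torch.float16                # narrowing override -> should warn

# Same shape as the guarded block in to_dtype(): skip silently if the
# attribute is absent, otherwise warn on a precision-losing combination.
if hasattr(model, "checkpoint_dtype") and not (
    model.checkpoint_dtype == override
    or (model.checkpoint_dtype == torch.float16 and override == torch.float32)
    or (model.checkpoint_dtype == torch.bfloat16 and override == torch.float32)
):
    logging.warning(
        f"Checkpoint dtype {model.checkpoint_dtype} precision is higher "
        f"than dtype override {override}."
    )

# The warning does not block the cast: the model is converted regardless.
model = model.to(dtype=override)
print(next(model.parameters()).dtype)  # torch.float16
```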
